Spaces:
Runtime error
Runtime error
Commit
·
91fb4ef
0
Parent(s):
initial commit log 🪵🦫
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +8 -0
- Dockerfile +44 -0
- README.md +96 -0
- accelerate_configs/compiled_1.yaml +22 -0
- accelerate_configs/deepspeed.yaml +23 -0
- accelerate_configs/uncompiled_1.yaml +17 -0
- accelerate_configs/uncompiled_2.yaml +17 -0
- accelerate_configs/uncompiled_8.yaml +17 -0
- app.py +1270 -0
- captioning_service.py +534 -0
- config.py +303 -0
- finetrainers/__init__.py +2 -0
- finetrainers/args.py +1191 -0
- finetrainers/constants.py +80 -0
- finetrainers/dataset.py +467 -0
- finetrainers/hooks/__init__.py +1 -0
- finetrainers/hooks/hooks.py +176 -0
- finetrainers/hooks/layerwise_upcasting.py +140 -0
- finetrainers/models/__init__.py +33 -0
- finetrainers/models/cogvideox/__init__.py +2 -0
- finetrainers/models/cogvideox/full_finetune.py +32 -0
- finetrainers/models/cogvideox/lora.py +334 -0
- finetrainers/models/cogvideox/utils.py +51 -0
- finetrainers/models/hunyuan_video/__init__.py +2 -0
- finetrainers/models/hunyuan_video/full_finetune.py +30 -0
- finetrainers/models/hunyuan_video/lora.py +368 -0
- finetrainers/models/ltx_video/__init__.py +2 -0
- finetrainers/models/ltx_video/full_finetune.py +30 -0
- finetrainers/models/ltx_video/lora.py +331 -0
- finetrainers/patches.py +50 -0
- finetrainers/state.py +24 -0
- finetrainers/trainer.py +1207 -0
- finetrainers/utils/__init__.py +13 -0
- finetrainers/utils/checkpointing.py +64 -0
- finetrainers/utils/data_utils.py +35 -0
- finetrainers/utils/diffusion_utils.py +145 -0
- finetrainers/utils/file_utils.py +44 -0
- finetrainers/utils/hub_utils.py +84 -0
- finetrainers/utils/memory_utils.py +58 -0
- finetrainers/utils/model_utils.py +25 -0
- finetrainers/utils/optimizer_utils.py +178 -0
- finetrainers/utils/torch_utils.py +35 -0
- finetrainers_utils.py +126 -0
- image_preprocessing.py +116 -0
- import_service.py +245 -0
- requirements.txt +43 -0
- requirements_without_flash_attention.txt +42 -0
- run.sh +5 -0
- setup.sh +7 -0
- setup_no_captions.sh +12 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
.venv
|
3 |
+
.data
|
4 |
+
__pycache__
|
5 |
+
*.mp3
|
6 |
+
*.mp4
|
7 |
+
*.zip
|
8 |
+
training_service.log
|
Dockerfile
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
|
2 |
+
|
3 |
+
# Prevent interactive prompts during build
|
4 |
+
ARG DEBIAN_FRONTEND=noninteractive
|
5 |
+
|
6 |
+
# Set environment variables
|
7 |
+
ENV PYTHONUNBUFFERED=1
|
8 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
9 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
10 |
+
|
11 |
+
# Install system dependencies
|
12 |
+
RUN apt-get update && apt-get install -y \
|
13 |
+
python3.10 \
|
14 |
+
python3-pip \
|
15 |
+
python3-dev \
|
16 |
+
git \
|
17 |
+
ffmpeg \
|
18 |
+
libsm6 \
|
19 |
+
libxext6 \
|
20 |
+
libgl1-mesa-glx \
|
21 |
+
libglib2.0-0 \
|
22 |
+
&& apt-get clean \
|
23 |
+
&& rm -rf /var/lib/apt/lists/*
|
24 |
+
|
25 |
+
# Create app directory
|
26 |
+
WORKDIR /app
|
27 |
+
|
28 |
+
# Install Python dependencies first for better caching
|
29 |
+
COPY requirements.txt .
|
30 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
31 |
+
|
32 |
+
# actually we found a way to put flash attention inside the requirements.txt
|
33 |
+
# so we are good, we don't need this anymore:
|
34 |
+
# RUN pip3 install --no-cache-dir -r requirements_without_flash_attention.txt
|
35 |
+
# RUN pip3 install wheel setuptools flash-attn --no-build-isolation --no-cache-dir
|
36 |
+
|
37 |
+
# Copy application files
|
38 |
+
COPY . .
|
39 |
+
|
40 |
+
# Expose Gradio port
|
41 |
+
EXPOSE 7860
|
42 |
+
|
43 |
+
# Run the application
|
44 |
+
CMD ["python3", "app.py"]
|
README.md
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Video Model Studio
|
3 |
+
emoji: 🎥
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: gray
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.15.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: apache-2.0
|
11 |
+
short_description: All-in-one tool for AI video training
|
12 |
+
---
|
13 |
+
|
14 |
+
# 🎥 Video Model Studio (VMS)
|
15 |
+
|
16 |
+
## Presentation
|
17 |
+
|
18 |
+
VMS is an all-in-one tool to train LoRA models for various open-source AI video models:
|
19 |
+
|
20 |
+
- Data collection from various sources
|
21 |
+
- Splitting videos into short single camera shots
|
22 |
+
- Automatic captioning
|
23 |
+
- Training HunyuanVideo or LTX-Video
|
24 |
+
|
25 |
+
## Similar projects
|
26 |
+
|
27 |
+
I wasn't aware of it when I started this project,
|
28 |
+
but there is also this: https://github.com/alisson-anjos/diffusion-pipe-ui
|
29 |
+
|
30 |
+
## Installation
|
31 |
+
|
32 |
+
VMS is built on top of Finetrainers and Gradio, and designed to run as a Hugging Face Space (but you can deploy it elsewhere if you want to).
|
33 |
+
|
34 |
+
### Full installation at Hugging Face
|
35 |
+
|
36 |
+
Easy peasy: create a Space (make sure to use the `Gradio` type/template), and push the repo. No Docker needed!
|
37 |
+
|
38 |
+
### Dev mode on Hugging Face
|
39 |
+
|
40 |
+
Enable dev mode in the space, then open VSCode in local or remote and run:
|
41 |
+
|
42 |
+
```
|
43 |
+
pip install -r requirements.txt
|
44 |
+
```
|
45 |
+
|
46 |
+
As this is not automatic, then click on "Restart" in the space dev mode UI widget.
|
47 |
+
|
48 |
+
### Full installation somewhere else
|
49 |
+
|
50 |
+
I haven't tested it, but you can try to provided Dockerfile
|
51 |
+
|
52 |
+
### Full installation in local
|
53 |
+
|
54 |
+
the full installation requires:
|
55 |
+
- Linux
|
56 |
+
- CUDA 12
|
57 |
+
- Python 3.10
|
58 |
+
|
59 |
+
This is because of flash attention, which is defined in the `requirements.txt` using an URL to download a prebuilt wheel (python bindings for a native library)
|
60 |
+
|
61 |
+
```bash
|
62 |
+
./setup.sh
|
63 |
+
```
|
64 |
+
|
65 |
+
### Degraded installation in local
|
66 |
+
|
67 |
+
If you cannot meet the requirements, you can:
|
68 |
+
|
69 |
+
- solution 1: fix requirements.txt to use another prebuilt wheel
|
70 |
+
- solution 2: manually build/install flash attention
|
71 |
+
- solution 3: don't use clip captioning
|
72 |
+
|
73 |
+
Here is how to do solution 3:
|
74 |
+
```bash
|
75 |
+
./setup_no_captions.sh
|
76 |
+
```
|
77 |
+
|
78 |
+
## Run
|
79 |
+
|
80 |
+
### Running the Gradio app
|
81 |
+
|
82 |
+
Note: please make sure you properly define the environment variables for `STORAGE_PATH` (eg. `/data/`) and `HF_HOME` (eg. `/data/huggingface/`)
|
83 |
+
|
84 |
+
```bash
|
85 |
+
python app.py
|
86 |
+
```
|
87 |
+
|
88 |
+
### Running locally
|
89 |
+
|
90 |
+
See above remarks about the environment variable.
|
91 |
+
|
92 |
+
By default `run.sh` will store stuff in `.data/` (located inside the current working directory):
|
93 |
+
|
94 |
+
```bash
|
95 |
+
./run.sh
|
96 |
+
```
|
accelerate_configs/compiled_1.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
dynamo_config:
|
6 |
+
dynamo_backend: INDUCTOR
|
7 |
+
dynamo_mode: max-autotune
|
8 |
+
dynamo_use_dynamic: true
|
9 |
+
dynamo_use_fullgraph: false
|
10 |
+
enable_cpu_affinity: false
|
11 |
+
gpu_ids: '3'
|
12 |
+
machine_rank: 0
|
13 |
+
main_training_function: main
|
14 |
+
mixed_precision: bf16
|
15 |
+
num_machines: 1
|
16 |
+
num_processes: 1
|
17 |
+
rdzv_backend: static
|
18 |
+
same_network: true
|
19 |
+
tpu_env: []
|
20 |
+
tpu_use_cluster: false
|
21 |
+
tpu_use_sudo: false
|
22 |
+
use_cpu: false
|
accelerate_configs/deepspeed.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
deepspeed_config:
|
4 |
+
gradient_accumulation_steps: 1
|
5 |
+
gradient_clipping: 1.0
|
6 |
+
offload_optimizer_device: cpu
|
7 |
+
offload_param_device: cpu
|
8 |
+
zero3_init_flag: false
|
9 |
+
zero_stage: 2
|
10 |
+
distributed_type: DEEPSPEED
|
11 |
+
downcast_bf16: 'no'
|
12 |
+
enable_cpu_affinity: false
|
13 |
+
machine_rank: 0
|
14 |
+
main_training_function: main
|
15 |
+
mixed_precision: bf16
|
16 |
+
num_machines: 1
|
17 |
+
num_processes: 2
|
18 |
+
rdzv_backend: static
|
19 |
+
same_network: true
|
20 |
+
tpu_env: []
|
21 |
+
tpu_use_cluster: false
|
22 |
+
tpu_use_sudo: false
|
23 |
+
use_cpu: false
|
accelerate_configs/uncompiled_1.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: '3'
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 1
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
accelerate_configs/uncompiled_2.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: 0,1
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 2
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
accelerate_configs/uncompiled_8.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
enable_cpu_affinity: false
|
6 |
+
gpu_ids: all
|
7 |
+
machine_rank: 0
|
8 |
+
main_training_function: main
|
9 |
+
mixed_precision: bf16
|
10 |
+
num_machines: 1
|
11 |
+
num_processes: 8
|
12 |
+
rdzv_backend: static
|
13 |
+
same_network: true
|
14 |
+
tpu_env: []
|
15 |
+
tpu_use_cluster: false
|
16 |
+
tpu_use_sudo: false
|
17 |
+
use_cpu: false
|
app.py
ADDED
@@ -0,0 +1,1270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import platform
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
#import sys
|
5 |
+
#print("python = ", sys.version)
|
6 |
+
|
7 |
+
# can be "Linux", "Darwin"
|
8 |
+
if platform.system() == "Linux":
|
9 |
+
# for some reason it says "pip not found"
|
10 |
+
# and also "pip3 not found"
|
11 |
+
# subprocess.run(
|
12 |
+
# "pip install flash-attn --no-build-isolation",
|
13 |
+
#
|
14 |
+
# # hmm... this should be False, since we are in a CUDA environment, no?
|
15 |
+
# env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
|
16 |
+
#
|
17 |
+
# shell=True,
|
18 |
+
# )
|
19 |
+
pass
|
20 |
+
|
21 |
+
import gradio as gr
|
22 |
+
from pathlib import Path
|
23 |
+
import logging
|
24 |
+
import mimetypes
|
25 |
+
import shutil
|
26 |
+
import os
|
27 |
+
import traceback
|
28 |
+
import asyncio
|
29 |
+
import tempfile
|
30 |
+
import zipfile
|
31 |
+
from typing import Any, Optional, Dict, List, Union, Tuple
|
32 |
+
from typing import AsyncGenerator
|
33 |
+
from training_service import TrainingService
|
34 |
+
from captioning_service import CaptioningService
|
35 |
+
from splitting_service import SplittingService
|
36 |
+
from import_service import ImportService
|
37 |
+
from config import (
|
38 |
+
STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
|
39 |
+
TRAINING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
|
40 |
+
DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS
|
41 |
+
)
|
42 |
+
from utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
|
43 |
+
from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
|
44 |
+
from training_log_parser import TrainingLogParser
|
45 |
+
|
46 |
+
logger = logging.getLogger(__name__)
|
47 |
+
logger.setLevel(logging.INFO)
|
48 |
+
|
49 |
+
httpx_logger = logging.getLogger('httpx')
|
50 |
+
httpx_logger.setLevel(logging.WARN)
|
51 |
+
|
52 |
+
|
53 |
+
class VideoTrainerUI:
|
54 |
+
def __init__(self):
|
55 |
+
self.trainer = TrainingService()
|
56 |
+
self.splitter = SplittingService()
|
57 |
+
self.importer = ImportService()
|
58 |
+
self.captioner = CaptioningService()
|
59 |
+
self._should_stop_captioning = False
|
60 |
+
self.log_parser = TrainingLogParser()
|
61 |
+
|
62 |
+
def update_training_ui(self, training_state: Dict[str, Any]):
|
63 |
+
"""Update UI components based on training state"""
|
64 |
+
updates = {}
|
65 |
+
|
66 |
+
# Update status box with high-level information
|
67 |
+
status_text = []
|
68 |
+
if training_state["status"] != "idle":
|
69 |
+
status_text.extend([
|
70 |
+
f"Status: {training_state['status']}",
|
71 |
+
f"Progress: {training_state['progress']}",
|
72 |
+
f"Step: {training_state['current_step']}/{training_state['total_steps']}",
|
73 |
+
|
74 |
+
# Epoch information
|
75 |
+
# there is an issue with how epoch is reported because we display:
|
76 |
+
# Progress: 96.9%, Step: 872/900, Epoch: 12/50
|
77 |
+
# we should probably just show the steps
|
78 |
+
#f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
|
79 |
+
|
80 |
+
f"Time elapsed: {training_state['elapsed']}",
|
81 |
+
f"Estimated remaining: {training_state['remaining']}",
|
82 |
+
"",
|
83 |
+
f"Current loss: {training_state['step_loss']}",
|
84 |
+
f"Learning rate: {training_state['learning_rate']}",
|
85 |
+
f"Gradient norm: {training_state['grad_norm']}",
|
86 |
+
f"Memory usage: {training_state['memory']}"
|
87 |
+
])
|
88 |
+
|
89 |
+
if training_state["error_message"]:
|
90 |
+
status_text.append(f"\nError: {training_state['error_message']}")
|
91 |
+
|
92 |
+
updates["status_box"] = "\n".join(status_text)
|
93 |
+
|
94 |
+
# Update button states
|
95 |
+
updates["start_btn"] = gr.Button(
|
96 |
+
"Start training",
|
97 |
+
interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
|
98 |
+
variant="primary" if training_state["status"] == "idle" else "secondary"
|
99 |
+
)
|
100 |
+
|
101 |
+
updates["stop_btn"] = gr.Button(
|
102 |
+
"Stop training",
|
103 |
+
interactive=(training_state["status"] in ["training", "initializing"]),
|
104 |
+
variant="stop"
|
105 |
+
)
|
106 |
+
|
107 |
+
return updates
|
108 |
+
|
109 |
+
def stop_all_and_clear(self) -> Dict[str, str]:
|
110 |
+
"""Stop all running processes and clear data
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
Dict with status messages for different components
|
114 |
+
"""
|
115 |
+
status_messages = {}
|
116 |
+
|
117 |
+
try:
|
118 |
+
# Stop training if running
|
119 |
+
if self.trainer.is_training_running():
|
120 |
+
training_result = self.trainer.stop_training()
|
121 |
+
status_messages["training"] = training_result["status"]
|
122 |
+
|
123 |
+
# Stop captioning if running
|
124 |
+
if self.captioner:
|
125 |
+
self.captioner.stop_captioning()
|
126 |
+
#self.captioner.close()
|
127 |
+
#self.captioner = None
|
128 |
+
status_messages["captioning"] = "Captioning stopped"
|
129 |
+
|
130 |
+
# Stop scene detection if running
|
131 |
+
if self.splitter.is_processing():
|
132 |
+
self.splitter.processing = False
|
133 |
+
status_messages["splitting"] = "Scene detection stopped"
|
134 |
+
|
135 |
+
# Clear all data directories
|
136 |
+
for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
|
137 |
+
MODEL_PATH, OUTPUT_PATH]:
|
138 |
+
if path.exists():
|
139 |
+
try:
|
140 |
+
shutil.rmtree(path)
|
141 |
+
path.mkdir(parents=True, exist_ok=True)
|
142 |
+
except Exception as e:
|
143 |
+
status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
|
144 |
+
else:
|
145 |
+
status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
|
146 |
+
|
147 |
+
# Reset any persistent state
|
148 |
+
self._should_stop_captioning = True
|
149 |
+
self.splitter.processing = False
|
150 |
+
|
151 |
+
return {
|
152 |
+
"status": "All processes stopped and data cleared",
|
153 |
+
"details": status_messages
|
154 |
+
}
|
155 |
+
|
156 |
+
except Exception as e:
|
157 |
+
return {
|
158 |
+
"status": f"Error during cleanup: {str(e)}",
|
159 |
+
"details": status_messages
|
160 |
+
}
|
161 |
+
|
162 |
+
def update_titles(self) -> Tuple[Any]:
|
163 |
+
"""Update all dynamic titles with current counts
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
Dict of Gradio updates
|
167 |
+
"""
|
168 |
+
# Count files for splitting
|
169 |
+
split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
|
170 |
+
split_title = format_media_title(
|
171 |
+
"split", split_videos, 0, split_size
|
172 |
+
)
|
173 |
+
|
174 |
+
# Count files for captioning
|
175 |
+
caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
|
176 |
+
caption_title = format_media_title(
|
177 |
+
"caption", caption_videos, caption_images, caption_size
|
178 |
+
)
|
179 |
+
|
180 |
+
# Count files for training
|
181 |
+
train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
|
182 |
+
train_title = format_media_title(
|
183 |
+
"train", train_videos, train_images, train_size
|
184 |
+
)
|
185 |
+
|
186 |
+
return (
|
187 |
+
gr.Markdown(value=split_title),
|
188 |
+
gr.Markdown(value=caption_title),
|
189 |
+
gr.Markdown(value=f"{train_title} available for training")
|
190 |
+
)
|
191 |
+
|
192 |
+
def copy_files_to_training_dir(self, prompt_prefix: str):
|
193 |
+
"""Run auto-captioning process"""
|
194 |
+
|
195 |
+
# Initialize captioner if not already done
|
196 |
+
self._should_stop_captioning = False
|
197 |
+
|
198 |
+
try:
|
199 |
+
copy_files_to_training_dir(prompt_prefix)
|
200 |
+
|
201 |
+
except Exception as e:
|
202 |
+
traceback.print_exc()
|
203 |
+
raise gr.Error(f"Error copying assets to training dir: {str(e)}")
|
204 |
+
|
205 |
+
async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
|
206 |
+
"""Run auto-captioning process"""
|
207 |
+
try:
|
208 |
+
# Initialize captioner if not already done
|
209 |
+
self._should_stop_captioning = False
|
210 |
+
|
211 |
+
async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
|
212 |
+
# Yield UI update
|
213 |
+
yield gr.update(
|
214 |
+
value=rows,
|
215 |
+
headers=["name", "status"]
|
216 |
+
)
|
217 |
+
|
218 |
+
# Final update after completion
|
219 |
+
yield gr.update(
|
220 |
+
value=self.list_training_files_to_caption(),
|
221 |
+
headers=["name", "status"]
|
222 |
+
)
|
223 |
+
|
224 |
+
except Exception as e:
|
225 |
+
yield gr.update(
|
226 |
+
value=[[str(e), "error"]],
|
227 |
+
headers=["name", "status"]
|
228 |
+
)
|
229 |
+
|
230 |
+
def list_training_files_to_caption(self) -> List[List[str]]:
|
231 |
+
"""List all clips and images - both pending and captioned"""
|
232 |
+
files = []
|
233 |
+
already_listed: Dict[str, bool] = {}
|
234 |
+
|
235 |
+
# Check files in STAGING_PATH
|
236 |
+
for file in STAGING_PATH.glob("*.*"):
|
237 |
+
if is_video_file(file) or is_image_file(file):
|
238 |
+
txt_file = file.with_suffix('.txt')
|
239 |
+
status = "captioned" if txt_file.exists() else "no caption"
|
240 |
+
file_type = "video" if is_video_file(file) else "image"
|
241 |
+
files.append([file.name, f"{status} ({file_type})", str(file)])
|
242 |
+
already_listed[str(file.name)] = True
|
243 |
+
|
244 |
+
# Check files in TRAINING_VIDEOS_PATH
|
245 |
+
for file in TRAINING_VIDEOS_PATH.glob("*.*"):
|
246 |
+
if not str(file.name) in already_listed:
|
247 |
+
if is_video_file(file) or is_image_file(file):
|
248 |
+
txt_file = file.with_suffix('.txt')
|
249 |
+
if txt_file.exists():
|
250 |
+
file_type = "video" if is_video_file(file) else "image"
|
251 |
+
files.append([file.name, f"captioned ({file_type})", str(file)])
|
252 |
+
|
253 |
+
# Sort by filename
|
254 |
+
files.sort(key=lambda x: x[0])
|
255 |
+
|
256 |
+
# Only return name and status columns for display
|
257 |
+
return [[file[0], file[1]] for file in files]
|
258 |
+
|
259 |
+
def update_training_buttons(self, training_state: Dict[str, Any]) -> Dict:
|
260 |
+
"""Update training control buttons based on state"""
|
261 |
+
is_training = training_state["status"] in ["training", "initializing"]
|
262 |
+
is_paused = training_state["status"] == "paused"
|
263 |
+
is_completed = training_state["status"] in ["completed", "error", "stopped"]
|
264 |
+
|
265 |
+
return {
|
266 |
+
start_btn: gr.Button(
|
267 |
+
interactive=not is_training and not is_paused,
|
268 |
+
variant="primary" if not is_training else "secondary",
|
269 |
+
),
|
270 |
+
stop_btn: gr.Button(
|
271 |
+
interactive=is_training or is_paused,
|
272 |
+
variant="stop",
|
273 |
+
),
|
274 |
+
pause_resume_btn: gr.Button(
|
275 |
+
value="Resume Training" if is_paused else "Pause Training",
|
276 |
+
interactive=(is_training or is_paused) and not is_completed,
|
277 |
+
variant="secondary",
|
278 |
+
)
|
279 |
+
}
|
280 |
+
|
281 |
+
def handle_training_complete(self):
|
282 |
+
"""Handle training completion"""
|
283 |
+
# Reset button states
|
284 |
+
return self.update_training_buttons({
|
285 |
+
"status": "completed",
|
286 |
+
"progress": "100%",
|
287 |
+
"current_step": 0,
|
288 |
+
"total_steps": 0
|
289 |
+
})
|
290 |
+
|
291 |
+
def handle_pause_resume(self):
|
292 |
+
status = self.trainer.get_status()
|
293 |
+
if status["state"] == "paused":
|
294 |
+
result = self.trainer.resume_training()
|
295 |
+
new_state = {"status": "training"}
|
296 |
+
else:
|
297 |
+
result = self.trainer.pause_training()
|
298 |
+
new_state = {"status": "paused"}
|
299 |
+
return (
|
300 |
+
*result,
|
301 |
+
*self.update_training_buttons(new_state).values()
|
302 |
+
)
|
303 |
+
|
304 |
+
|
305 |
+
def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
306 |
+
"""Handle selection of both video clips and images"""
|
307 |
+
try:
|
308 |
+
if not evt:
|
309 |
+
return [
|
310 |
+
gr.Image(
|
311 |
+
interactive=False,
|
312 |
+
visible=False
|
313 |
+
),
|
314 |
+
gr.Video(
|
315 |
+
interactive=False,
|
316 |
+
visible=False
|
317 |
+
),
|
318 |
+
gr.Textbox(
|
319 |
+
visible=False
|
320 |
+
),
|
321 |
+
"No file selected"
|
322 |
+
]
|
323 |
+
|
324 |
+
file_name = evt.value
|
325 |
+
if not file_name:
|
326 |
+
return [
|
327 |
+
gr.Image(
|
328 |
+
interactive=False,
|
329 |
+
visible=False
|
330 |
+
),
|
331 |
+
gr.Video(
|
332 |
+
interactive=False,
|
333 |
+
visible=False
|
334 |
+
),
|
335 |
+
gr.Textbox(
|
336 |
+
visible=False
|
337 |
+
),
|
338 |
+
"No file selected"
|
339 |
+
]
|
340 |
+
|
341 |
+
# Check both possible locations for the file
|
342 |
+
possible_paths = [
|
343 |
+
STAGING_PATH / file_name,
|
344 |
+
|
345 |
+
# note: we use to look into this dir for already-captioned clips,
|
346 |
+
# but we don't do this anymore
|
347 |
+
#TRAINING_VIDEOS_PATH / file_name
|
348 |
+
]
|
349 |
+
|
350 |
+
# Find the first existing file path
|
351 |
+
file_path = None
|
352 |
+
for path in possible_paths:
|
353 |
+
if path.exists():
|
354 |
+
file_path = path
|
355 |
+
break
|
356 |
+
|
357 |
+
if not file_path:
|
358 |
+
return [
|
359 |
+
gr.Image(
|
360 |
+
interactive=False,
|
361 |
+
visible=False
|
362 |
+
),
|
363 |
+
gr.Video(
|
364 |
+
interactive=False,
|
365 |
+
visible=False
|
366 |
+
),
|
367 |
+
gr.Textbox(
|
368 |
+
visible=False
|
369 |
+
),
|
370 |
+
f"File not found: {file_name}"
|
371 |
+
]
|
372 |
+
|
373 |
+
txt_path = file_path.with_suffix('.txt')
|
374 |
+
caption = txt_path.read_text() if txt_path.exists() else ""
|
375 |
+
|
376 |
+
# Handle video files
|
377 |
+
if is_video_file(file_path):
|
378 |
+
return [
|
379 |
+
gr.Image(
|
380 |
+
interactive=False,
|
381 |
+
visible=False
|
382 |
+
),
|
383 |
+
gr.Video(
|
384 |
+
label="Video Preview",
|
385 |
+
interactive=False,
|
386 |
+
visible=True,
|
387 |
+
value=str(file_path)
|
388 |
+
),
|
389 |
+
gr.Textbox(
|
390 |
+
label="Caption",
|
391 |
+
lines=6,
|
392 |
+
interactive=True,
|
393 |
+
visible=True,
|
394 |
+
value=str(caption)
|
395 |
+
),
|
396 |
+
None
|
397 |
+
]
|
398 |
+
# Handle image files
|
399 |
+
elif is_image_file(file_path):
|
400 |
+
return [
|
401 |
+
gr.Image(
|
402 |
+
label="Image Preview",
|
403 |
+
interactive=False,
|
404 |
+
visible=True,
|
405 |
+
value=str(file_path)
|
406 |
+
),
|
407 |
+
gr.Video(
|
408 |
+
interactive=False,
|
409 |
+
visible=False
|
410 |
+
),
|
411 |
+
gr.Textbox(
|
412 |
+
label="Caption",
|
413 |
+
lines=6,
|
414 |
+
interactive=True,
|
415 |
+
visible=True,
|
416 |
+
value=str(caption)
|
417 |
+
),
|
418 |
+
None
|
419 |
+
]
|
420 |
+
else:
|
421 |
+
return [
|
422 |
+
gr.Image(
|
423 |
+
interactive=False,
|
424 |
+
visible=False
|
425 |
+
),
|
426 |
+
gr.Video(
|
427 |
+
interactive=False,
|
428 |
+
visible=False
|
429 |
+
),
|
430 |
+
gr.Textbox(
|
431 |
+
interactive=False,
|
432 |
+
visible=False
|
433 |
+
),
|
434 |
+
f"Unsupported file type: {file_path.suffix}"
|
435 |
+
]
|
436 |
+
except Exception as e:
|
437 |
+
logger.error(f"Error handling selection: {str(e)}")
|
438 |
+
return [
|
439 |
+
gr.Image(
|
440 |
+
interactive=False,
|
441 |
+
visible=False
|
442 |
+
),
|
443 |
+
gr.Video(
|
444 |
+
interactive=False,
|
445 |
+
visible=False
|
446 |
+
),
|
447 |
+
gr.Textbox(
|
448 |
+
interactive=False,
|
449 |
+
visible=False
|
450 |
+
),
|
451 |
+
f"Error handling selection: {str(e)}"
|
452 |
+
]
|
453 |
+
|
454 |
+
def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, prompt_prefix: str):
|
455 |
+
"""Save changes to caption"""
|
456 |
+
try:
|
457 |
+
# Add prefix if not already present
|
458 |
+
if prompt_prefix and not preview_caption.startswith(prompt_prefix):
|
459 |
+
full_caption = f"{prompt_prefix}{preview_caption}"
|
460 |
+
else:
|
461 |
+
full_caption = preview_caption
|
462 |
+
|
463 |
+
path = Path(preview_video if preview_video else preview_image)
|
464 |
+
if path.suffix == '.txt':
|
465 |
+
self.trainer.update_file_caption(path.with_suffix(''), full_caption)
|
466 |
+
else:
|
467 |
+
self.trainer.update_file_caption(path, full_caption)
|
468 |
+
return gr.update(value="Caption saved successfully!")
|
469 |
+
except Exception as e:
|
470 |
+
return gr.update(value=f"Error saving caption: {str(e)}")
|
471 |
+
|
472 |
+
def get_model_info(self, model_type: str) -> str:
|
473 |
+
"""Get information about the selected model type"""
|
474 |
+
if model_type == "hunyuan_video":
|
475 |
+
return """### HunyuanVideo (LoRA)
|
476 |
+
- Best for learning complex video generation patterns
|
477 |
+
- Required VRAM: ~47GB minimum
|
478 |
+
- Recommended batch size: 1-2
|
479 |
+
- Typical training time: 2-4 hours
|
480 |
+
- Default resolution: 49x512x768
|
481 |
+
- Default LoRA rank: 128"""
|
482 |
+
|
483 |
+
elif model_type == "ltx_video":
|
484 |
+
return """### LTX-Video (LoRA)
|
485 |
+
- Lightweight video model
|
486 |
+
- Required VRAM: ~18GB minimum
|
487 |
+
- Recommended batch size: 1-4
|
488 |
+
- Typical training time: 1-3 hours
|
489 |
+
- Default resolution: 49x512x768
|
490 |
+
- Default LoRA rank: 128"""
|
491 |
+
|
492 |
+
return ""
|
493 |
+
|
494 |
+
def get_default_params(self, model_type: str) -> Dict[str, Any]:
|
495 |
+
"""Get default training parameters for model type"""
|
496 |
+
if model_type == "hunyuan_video":
|
497 |
+
return {
|
498 |
+
"num_epochs": 70,
|
499 |
+
"batch_size": 1,
|
500 |
+
"learning_rate": 2e-5,
|
501 |
+
"save_iterations": 500,
|
502 |
+
"video_resolution_buckets": TRAINING_BUCKETS,
|
503 |
+
"video_reshape_mode": "center",
|
504 |
+
"caption_dropout_p": 0.05,
|
505 |
+
"gradient_accumulation_steps": 1,
|
506 |
+
"rank": 128,
|
507 |
+
"lora_alpha": 128
|
508 |
+
}
|
509 |
+
else: # ltx_video
|
510 |
+
return {
|
511 |
+
"num_epochs": 70,
|
512 |
+
"batch_size": 1,
|
513 |
+
"learning_rate": 3e-5,
|
514 |
+
"save_iterations": 500,
|
515 |
+
"video_resolution_buckets": TRAINING_BUCKETS,
|
516 |
+
"video_reshape_mode": "center",
|
517 |
+
"caption_dropout_p": 0.05,
|
518 |
+
"gradient_accumulation_steps": 4,
|
519 |
+
"rank": 128,
|
520 |
+
"lora_alpha": 128
|
521 |
+
}
|
522 |
+
|
523 |
+
def preview_file(self, selected_text: str) -> Dict:
|
524 |
+
"""Generate preview based on selected file
|
525 |
+
|
526 |
+
Args:
|
527 |
+
selected_text: Text of the selected item containing filename
|
528 |
+
|
529 |
+
Returns:
|
530 |
+
Dict with preview content for each preview component
|
531 |
+
"""
|
532 |
+
if not selected_text or "Caption:" in selected_text:
|
533 |
+
return {
|
534 |
+
"video": None,
|
535 |
+
"image": None,
|
536 |
+
"text": None
|
537 |
+
}
|
538 |
+
|
539 |
+
# Extract filename from the preview text (remove size info)
|
540 |
+
filename = selected_text.split(" (")[0].strip()
|
541 |
+
file_path = TRAINING_VIDEOS_PATH / filename
|
542 |
+
|
543 |
+
if not file_path.exists():
|
544 |
+
return {
|
545 |
+
"video": None,
|
546 |
+
"image": None,
|
547 |
+
"text": f"File not found: {filename}"
|
548 |
+
}
|
549 |
+
|
550 |
+
# Detect file type
|
551 |
+
mime_type, _ = mimetypes.guess_type(str(file_path))
|
552 |
+
if not mime_type:
|
553 |
+
return {
|
554 |
+
"video": None,
|
555 |
+
"image": None,
|
556 |
+
"text": f"Unknown file type: {filename}"
|
557 |
+
}
|
558 |
+
|
559 |
+
# Return appropriate preview
|
560 |
+
if mime_type.startswith('video/'):
|
561 |
+
return {
|
562 |
+
"video": str(file_path),
|
563 |
+
"image": None,
|
564 |
+
"text": None
|
565 |
+
}
|
566 |
+
elif mime_type.startswith('image/'):
|
567 |
+
return {
|
568 |
+
"video": None,
|
569 |
+
"image": str(file_path),
|
570 |
+
"text": None
|
571 |
+
}
|
572 |
+
elif mime_type.startswith('text/'):
|
573 |
+
try:
|
574 |
+
text_content = file_path.read_text()
|
575 |
+
return {
|
576 |
+
"video": None,
|
577 |
+
"image": None,
|
578 |
+
"text": text_content
|
579 |
+
}
|
580 |
+
except Exception as e:
|
581 |
+
return {
|
582 |
+
"video": None,
|
583 |
+
"image": None,
|
584 |
+
"text": f"Error reading file: {str(e)}"
|
585 |
+
}
|
586 |
+
else:
|
587 |
+
return {
|
588 |
+
"video": None,
|
589 |
+
"image": None,
|
590 |
+
"text": f"Unsupported file type: {mime_type}"
|
591 |
+
}
|
592 |
+
|
593 |
+
def list_unprocessed_videos(self) -> gr.Dataframe:
|
594 |
+
"""Update list of unprocessed videos"""
|
595 |
+
videos = self.splitter.list_unprocessed_videos()
|
596 |
+
# videos is already in [[name, status]] format from splitting_service
|
597 |
+
return gr.Dataframe(
|
598 |
+
headers=["name", "status"],
|
599 |
+
value=videos,
|
600 |
+
interactive=False
|
601 |
+
)
|
602 |
+
|
603 |
+
async def start_scene_detection(self, enable_splitting: bool) -> str:
|
604 |
+
"""Start background scene detection process
|
605 |
+
|
606 |
+
Args:
|
607 |
+
enable_splitting: Whether to split videos into scenes
|
608 |
+
"""
|
609 |
+
if self.splitter.is_processing():
|
610 |
+
return "Scene detection already running"
|
611 |
+
|
612 |
+
try:
|
613 |
+
await self.splitter.start_processing(enable_splitting)
|
614 |
+
return "Scene detection completed"
|
615 |
+
except Exception as e:
|
616 |
+
return f"Error during scene detection: {str(e)}"
|
617 |
+
|
618 |
+
|
619 |
+
def refresh_training_status_and_logs(self):
|
620 |
+
"""Refresh all dynamic lists and training state"""
|
621 |
+
status = self.trainer.get_status()
|
622 |
+
logs = self.trainer.get_logs()
|
623 |
+
|
624 |
+
status_update = status["message"]
|
625 |
+
|
626 |
+
# Parse new log lines
|
627 |
+
if logs:
|
628 |
+
last_state = None
|
629 |
+
for line in logs.splitlines():
|
630 |
+
state_update = self.log_parser.parse_line(line)
|
631 |
+
if state_update:
|
632 |
+
last_state = state_update
|
633 |
+
|
634 |
+
if last_state:
|
635 |
+
ui_updates = self.update_training_ui(last_state)
|
636 |
+
status_update = ui_updates.get("status_box", status["message"])
|
637 |
+
|
638 |
+
return (status_update, logs)
|
639 |
+
|
640 |
+
def refresh_training_status(self):
|
641 |
+
"""Refresh training status and update UI"""
|
642 |
+
status, logs = self.refresh_training_status_and_logs()
|
643 |
+
|
644 |
+
# Parse status for training state
|
645 |
+
is_completed = "completed" in status.lower() or "100.0%" in status
|
646 |
+
current_state = {
|
647 |
+
"status": "completed" if is_completed else "training",
|
648 |
+
"message": status
|
649 |
+
}
|
650 |
+
|
651 |
+
if is_completed:
|
652 |
+
button_updates = self.handle_training_complete()
|
653 |
+
return (
|
654 |
+
status,
|
655 |
+
logs,
|
656 |
+
*button_updates.values()
|
657 |
+
)
|
658 |
+
|
659 |
+
# Update based on current training state
|
660 |
+
button_updates = self.update_training_buttons(current_state)
|
661 |
+
return (
|
662 |
+
status,
|
663 |
+
logs,
|
664 |
+
*button_updates.values()
|
665 |
+
)
|
666 |
+
|
667 |
+
def refresh_dataset(self):
|
668 |
+
"""Refresh all dynamic lists and training state"""
|
669 |
+
video_list = self.splitter.list_unprocessed_videos()
|
670 |
+
training_dataset = self.list_training_files_to_caption()
|
671 |
+
|
672 |
+
return (
|
673 |
+
video_list,
|
674 |
+
training_dataset
|
675 |
+
)
|
676 |
+
|
677 |
+
def create_ui(self):
|
678 |
+
"""Create Gradio interface"""
|
679 |
+
|
680 |
+
with gr.Blocks(title="🎥 Video Model Studio") as app:
|
681 |
+
gr.Markdown("# 🎥 Video Model Studio")
|
682 |
+
|
683 |
+
with gr.Tabs() as tabs:
|
684 |
+
with gr.TabItem("1️⃣ Import", id="import_tab"):
|
685 |
+
|
686 |
+
with gr.Row():
|
687 |
+
gr.Markdown("## Optional: automated data cleaning")
|
688 |
+
|
689 |
+
with gr.Row():
|
690 |
+
enable_automatic_video_split = gr.Checkbox(
|
691 |
+
label="Automatically split videos into smaller clips",
|
692 |
+
info="Note: a clip is a single camera shot, usually a few seconds",
|
693 |
+
value=True,
|
694 |
+
visible=False
|
695 |
+
)
|
696 |
+
enable_automatic_content_captioning = gr.Checkbox(
|
697 |
+
label="Automatically caption photos and videos",
|
698 |
+
info="Note: this uses LlaVA and takes some extra time to load and process",
|
699 |
+
value=False,
|
700 |
+
visible=False,
|
701 |
+
)
|
702 |
+
|
703 |
+
with gr.Row():
|
704 |
+
with gr.Column(scale=3):
|
705 |
+
with gr.Row():
|
706 |
+
with gr.Column():
|
707 |
+
gr.Markdown("## Import video files")
|
708 |
+
gr.Markdown("You can upload either:")
|
709 |
+
gr.Markdown("- A single MP4 video file")
|
710 |
+
gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
|
711 |
+
gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
|
712 |
+
|
713 |
+
with gr.Row():
|
714 |
+
files = gr.Files(
|
715 |
+
label="Upload Images, Videos or ZIP",
|
716 |
+
#file_count="multiple",
|
717 |
+
file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"],
|
718 |
+
type="filepath"
|
719 |
+
)
|
720 |
+
|
721 |
+
with gr.Column(scale=3):
|
722 |
+
with gr.Row():
|
723 |
+
with gr.Column():
|
724 |
+
gr.Markdown("## Import a YouTube video")
|
725 |
+
gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
|
726 |
+
|
727 |
+
with gr.Row():
|
728 |
+
youtube_url = gr.Textbox(
|
729 |
+
label="Import YouTube Video",
|
730 |
+
placeholder="https://www.youtube.com/watch?v=..."
|
731 |
+
)
|
732 |
+
with gr.Row():
|
733 |
+
youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary")
|
734 |
+
with gr.Row():
|
735 |
+
import_status = gr.Textbox(label="Status", interactive=False)
|
736 |
+
|
737 |
+
|
738 |
+
with gr.TabItem("2️⃣ Split", id="split_tab"):
|
739 |
+
with gr.Row():
|
740 |
+
split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)")
|
741 |
+
|
742 |
+
with gr.Row():
|
743 |
+
with gr.Column():
|
744 |
+
detect_btn = gr.Button("Split videos into single-camera shots", variant="primary")
|
745 |
+
detect_status = gr.Textbox(label="Status", interactive=False)
|
746 |
+
|
747 |
+
with gr.Column():
|
748 |
+
|
749 |
+
video_list = gr.Dataframe(
|
750 |
+
headers=["name", "status"],
|
751 |
+
label="Videos to split",
|
752 |
+
interactive=False,
|
753 |
+
wrap=True,
|
754 |
+
#selection_mode="cell" # Enable cell selection
|
755 |
+
)
|
756 |
+
|
757 |
+
|
758 |
+
with gr.TabItem("3️⃣ Caption"):
|
759 |
+
with gr.Row():
|
760 |
+
caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)")
|
761 |
+
|
762 |
+
with gr.Row():
|
763 |
+
|
764 |
+
with gr.Column():
|
765 |
+
with gr.Row():
|
766 |
+
custom_prompt_prefix = gr.Textbox(
|
767 |
+
scale=3,
|
768 |
+
label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
|
769 |
+
placeholder="In the style of TOK, ",
|
770 |
+
lines=2,
|
771 |
+
value=DEFAULT_PROMPT_PREFIX
|
772 |
+
)
|
773 |
+
captioning_bot_instructions = gr.Textbox(
|
774 |
+
scale=6,
|
775 |
+
label="System instructions for the automatic captioning model",
|
776 |
+
placeholder="Please generate a full description of...",
|
777 |
+
lines=5,
|
778 |
+
value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
|
779 |
+
)
|
780 |
+
with gr.Row():
|
781 |
+
run_autocaption_btn = gr.Button(
|
782 |
+
"Automatically fill missing captions",
|
783 |
+
variant="primary" # Makes it green by default
|
784 |
+
)
|
785 |
+
copy_files_to_training_dir_btn = gr.Button(
|
786 |
+
"Copy assets to training directory",
|
787 |
+
variant="primary" # Makes it green by default
|
788 |
+
)
|
789 |
+
stop_autocaption_btn = gr.Button(
|
790 |
+
"Stop Captioning",
|
791 |
+
variant="stop", # Red when enabled
|
792 |
+
interactive=False # Disabled by default
|
793 |
+
)
|
794 |
+
|
795 |
+
with gr.Row():
|
796 |
+
with gr.Column():
|
797 |
+
training_dataset = gr.Dataframe(
|
798 |
+
headers=["name", "status"],
|
799 |
+
interactive=False,
|
800 |
+
wrap=True,
|
801 |
+
value=self.list_training_files_to_caption(),
|
802 |
+
row_count=10, # Optional: set a reasonable row count
|
803 |
+
#selection_mode="cell"
|
804 |
+
)
|
805 |
+
|
806 |
+
with gr.Column():
|
807 |
+
preview_video = gr.Video(
|
808 |
+
label="Video Preview",
|
809 |
+
interactive=False,
|
810 |
+
visible=False
|
811 |
+
)
|
812 |
+
preview_image = gr.Image(
|
813 |
+
label="Image Preview",
|
814 |
+
interactive=False,
|
815 |
+
visible=False
|
816 |
+
)
|
817 |
+
preview_caption = gr.Textbox(
|
818 |
+
label="Caption",
|
819 |
+
lines=6,
|
820 |
+
interactive=True
|
821 |
+
)
|
822 |
+
save_caption_btn = gr.Button("Save Caption")
|
823 |
+
preview_status = gr.Textbox(
|
824 |
+
label="Status",
|
825 |
+
interactive=False,
|
826 |
+
visible=True
|
827 |
+
)
|
828 |
+
|
829 |
+
with gr.TabItem("4️⃣ Train"):
|
830 |
+
with gr.Row():
|
831 |
+
with gr.Column():
|
832 |
+
|
833 |
+
with gr.Row():
|
834 |
+
train_title = gr.Markdown("## 0 files available for training (0 bytes)")
|
835 |
+
|
836 |
+
with gr.Row():
|
837 |
+
with gr.Column():
|
838 |
+
model_type = gr.Dropdown(
|
839 |
+
choices=list(MODEL_TYPES.keys()),
|
840 |
+
label="Model Type",
|
841 |
+
value=list(MODEL_TYPES.keys())[0]
|
842 |
+
)
|
843 |
+
model_info = gr.Markdown(
|
844 |
+
value=self.get_model_info(list(MODEL_TYPES.keys())[0])
|
845 |
+
)
|
846 |
+
|
847 |
+
with gr.Row():
|
848 |
+
lora_rank = gr.Dropdown(
|
849 |
+
label="LoRA Rank",
|
850 |
+
choices=["16", "32", "64", "128", "256"],
|
851 |
+
value="128",
|
852 |
+
type="value"
|
853 |
+
)
|
854 |
+
lora_alpha = gr.Dropdown(
|
855 |
+
label="LoRA Alpha",
|
856 |
+
choices=["16", "32", "64", "128", "256"],
|
857 |
+
value="128",
|
858 |
+
type="value"
|
859 |
+
)
|
860 |
+
with gr.Row():
|
861 |
+
num_epochs = gr.Number(
|
862 |
+
label="Number of Epochs",
|
863 |
+
value=70,
|
864 |
+
minimum=1,
|
865 |
+
precision=0
|
866 |
+
)
|
867 |
+
batch_size = gr.Number(
|
868 |
+
label="Batch Size",
|
869 |
+
value=1,
|
870 |
+
minimum=1,
|
871 |
+
precision=0
|
872 |
+
)
|
873 |
+
with gr.Row():
|
874 |
+
learning_rate = gr.Number(
|
875 |
+
label="Learning Rate",
|
876 |
+
value=2e-5,
|
877 |
+
minimum=1e-7
|
878 |
+
)
|
879 |
+
save_iterations = gr.Number(
|
880 |
+
label="Save checkpoint every N iterations",
|
881 |
+
value=500,
|
882 |
+
minimum=50,
|
883 |
+
precision=0,
|
884 |
+
info="Model will be saved periodically after these many steps"
|
885 |
+
)
|
886 |
+
|
887 |
+
with gr.Column():
|
888 |
+
with gr.Row():
|
889 |
+
start_btn = gr.Button(
|
890 |
+
"Start Training",
|
891 |
+
variant="primary",
|
892 |
+
interactive=not ASK_USER_TO_DUPLICATE_SPACE
|
893 |
+
)
|
894 |
+
pause_resume_btn = gr.Button(
|
895 |
+
"Resume Training",
|
896 |
+
variant="secondary",
|
897 |
+
interactive=False
|
898 |
+
)
|
899 |
+
stop_btn = gr.Button(
|
900 |
+
"Stop Training",
|
901 |
+
variant="stop",
|
902 |
+
interactive=False
|
903 |
+
)
|
904 |
+
|
905 |
+
with gr.Row():
|
906 |
+
with gr.Column():
|
907 |
+
status_box = gr.Textbox(
|
908 |
+
label="Training Status",
|
909 |
+
interactive=False,
|
910 |
+
lines=4
|
911 |
+
)
|
912 |
+
log_box = gr.TextArea(
|
913 |
+
label="Training Logs",
|
914 |
+
interactive=False,
|
915 |
+
lines=10,
|
916 |
+
max_lines=40,
|
917 |
+
autoscroll=True
|
918 |
+
)
|
919 |
+
|
920 |
+
with gr.TabItem("5️⃣ Manage"):
|
921 |
+
|
922 |
+
with gr.Column():
|
923 |
+
with gr.Row():
|
924 |
+
with gr.Column():
|
925 |
+
gr.Markdown("## Publishing")
|
926 |
+
gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
|
927 |
+
|
928 |
+
with gr.Row():
|
929 |
+
|
930 |
+
with gr.Column():
|
931 |
+
repo_id = gr.Textbox(
|
932 |
+
label="HuggingFace Model Repository",
|
933 |
+
placeholder="username/model-name",
|
934 |
+
info="The repository will be created if it doesn't exist"
|
935 |
+
)
|
936 |
+
gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="You model is private by default"),
|
937 |
+
global_stop_btn = gr.Button(
|
938 |
+
"Push my model",
|
939 |
+
#variant="stop"
|
940 |
+
)
|
941 |
+
|
942 |
+
|
943 |
+
with gr.Row():
|
944 |
+
with gr.Column():
|
945 |
+
with gr.Row():
|
946 |
+
with gr.Column():
|
947 |
+
gr.Markdown("## Storage management")
|
948 |
+
with gr.Row():
|
949 |
+
download_dataset_btn = gr.DownloadButton(
|
950 |
+
"Download dataset",
|
951 |
+
variant="secondary",
|
952 |
+
size="lg"
|
953 |
+
)
|
954 |
+
download_model_btn = gr.DownloadButton(
|
955 |
+
"Download model",
|
956 |
+
variant="secondary",
|
957 |
+
size="lg"
|
958 |
+
)
|
959 |
+
|
960 |
+
|
961 |
+
with gr.Row():
|
962 |
+
global_stop_btn = gr.Button(
|
963 |
+
"Stop everything and delete my data",
|
964 |
+
variant="stop"
|
965 |
+
)
|
966 |
+
global_status = gr.Textbox(
|
967 |
+
label="Global Status",
|
968 |
+
interactive=False,
|
969 |
+
visible=False
|
970 |
+
)
|
971 |
+
|
972 |
+
|
973 |
+
|
974 |
+
# Event handlers
|
975 |
+
def update_model_info(model):
|
976 |
+
params = self.get_default_params(MODEL_TYPES[model])
|
977 |
+
info = self.get_model_info(MODEL_TYPES[model])
|
978 |
+
return {
|
979 |
+
model_info: info,
|
980 |
+
num_epochs: params["num_epochs"],
|
981 |
+
batch_size: params["batch_size"],
|
982 |
+
learning_rate: params["learning_rate"],
|
983 |
+
save_iterations: params["save_iterations"]
|
984 |
+
}
|
985 |
+
|
986 |
+
def validate_repo(repo_id: str) -> dict:
|
987 |
+
validation = validate_model_repo(repo_id)
|
988 |
+
if validation["error"]:
|
989 |
+
return gr.update(value=repo_id, error=validation["error"])
|
990 |
+
return gr.update(value=repo_id, error=None)
|
991 |
+
|
992 |
+
# Connect events
|
993 |
+
model_type.change(
|
994 |
+
fn=update_model_info,
|
995 |
+
inputs=[model_type],
|
996 |
+
outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
|
997 |
+
)
|
998 |
+
|
999 |
+
async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
1000 |
+
videos = self.list_unprocessed_videos()
|
1001 |
+
# If scene detection isn't already running and there are videos to process,
|
1002 |
+
# and auto-splitting is enabled, start the detection
|
1003 |
+
if videos and not self.splitter.is_processing() and enable_splitting:
|
1004 |
+
await self.start_scene_detection(enable_splitting)
|
1005 |
+
msg = "Starting automatic scene detection..."
|
1006 |
+
else:
|
1007 |
+
# Just copy files without splitting if auto-split disabled
|
1008 |
+
for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
|
1009 |
+
await self.splitter.process_video(video_file, enable_splitting=False)
|
1010 |
+
msg = "Copying videos without splitting..."
|
1011 |
+
|
1012 |
+
copy_files_to_training_dir(prompt_prefix)
|
1013 |
+
|
1014 |
+
# Start auto-captioning if enabled
|
1015 |
+
if enable_automatic_content_captioning:
|
1016 |
+
await self.start_caption_generation(
|
1017 |
+
DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
|
1018 |
+
prompt_prefix
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
return {
|
1022 |
+
tabs: gr.Tabs(selected="split_tab"),
|
1023 |
+
video_list: videos,
|
1024 |
+
detect_status: msg
|
1025 |
+
}
|
1026 |
+
|
1027 |
+
|
1028 |
+
async def update_titles_after_import(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
1029 |
+
"""Handle post-import updates including titles"""
|
1030 |
+
import_result = await on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
|
1031 |
+
titles = self.update_titles()
|
1032 |
+
return (*import_result, *titles)
|
1033 |
+
|
1034 |
+
files.upload(
|
1035 |
+
fn=lambda x: self.importer.process_uploaded_files(x),
|
1036 |
+
inputs=[files],
|
1037 |
+
outputs=[import_status]
|
1038 |
+
).success(
|
1039 |
+
fn=update_titles_after_import,
|
1040 |
+
inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
|
1041 |
+
outputs=[
|
1042 |
+
tabs, video_list, detect_status,
|
1043 |
+
split_title, caption_title, train_title
|
1044 |
+
]
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
youtube_download_btn.click(
|
1048 |
+
fn=self.importer.download_youtube_video,
|
1049 |
+
inputs=[youtube_url],
|
1050 |
+
outputs=[import_status]
|
1051 |
+
).success(
|
1052 |
+
fn=on_import_success,
|
1053 |
+
inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
|
1054 |
+
outputs=[tabs, video_list, detect_status]
|
1055 |
+
)
|
1056 |
+
|
1057 |
+
# Scene detection events
|
1058 |
+
detect_btn.click(
|
1059 |
+
fn=self.start_scene_detection,
|
1060 |
+
inputs=[enable_automatic_video_split],
|
1061 |
+
outputs=[detect_status]
|
1062 |
+
)
|
1063 |
+
|
1064 |
+
|
1065 |
+
# Update button states based on captioning status
|
1066 |
+
def update_button_states(is_running):
|
1067 |
+
return {
|
1068 |
+
run_autocaption_btn: gr.Button(
|
1069 |
+
interactive=not is_running,
|
1070 |
+
variant="secondary" if is_running else "primary",
|
1071 |
+
),
|
1072 |
+
stop_autocaption_btn: gr.Button(
|
1073 |
+
interactive=is_running,
|
1074 |
+
variant="secondary",
|
1075 |
+
),
|
1076 |
+
}
|
1077 |
+
|
1078 |
+
run_autocaption_btn.click(
|
1079 |
+
fn=self.start_caption_generation,
|
1080 |
+
inputs=[captioning_bot_instructions, custom_prompt_prefix],
|
1081 |
+
outputs=[training_dataset],
|
1082 |
+
).then(
|
1083 |
+
fn=lambda: update_button_states(True),
|
1084 |
+
outputs=[run_autocaption_btn, stop_autocaption_btn]
|
1085 |
+
)
|
1086 |
+
|
1087 |
+
copy_files_to_training_dir_btn.click(
|
1088 |
+
fn=self.copy_files_to_training_dir,
|
1089 |
+
inputs=[custom_prompt_prefix]
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
stop_autocaption_btn.click(
|
1093 |
+
fn=lambda: (self.captioner.stop_captioning() if self.captioner else None, update_button_states(False)),
|
1094 |
+
outputs=[run_autocaption_btn, stop_autocaption_btn]
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
training_dataset.select(
|
1098 |
+
fn=self.handle_training_dataset_select,
|
1099 |
+
outputs=[preview_image, preview_video, preview_caption, preview_status]
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
save_caption_btn.click(
|
1103 |
+
fn=self.save_caption_changes,
|
1104 |
+
inputs=[preview_caption, preview_image, preview_video, custom_prompt_prefix],
|
1105 |
+
outputs=[preview_status]
|
1106 |
+
).success(
|
1107 |
+
fn=self.list_training_files_to_caption,
|
1108 |
+
outputs=[training_dataset]
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
# Training control events
|
1112 |
+
start_btn.click(
|
1113 |
+
fn=lambda model_type, *args: (
|
1114 |
+
self.log_parser.reset(),
|
1115 |
+
self.trainer.start_training(
|
1116 |
+
MODEL_TYPES[model_type],
|
1117 |
+
*args
|
1118 |
+
)
|
1119 |
+
),
|
1120 |
+
inputs=[
|
1121 |
+
model_type,
|
1122 |
+
lora_rank,
|
1123 |
+
lora_alpha,
|
1124 |
+
num_epochs,
|
1125 |
+
batch_size,
|
1126 |
+
learning_rate,
|
1127 |
+
save_iterations,
|
1128 |
+
repo_id
|
1129 |
+
],
|
1130 |
+
outputs=[status_box, log_box]
|
1131 |
+
).success(
|
1132 |
+
fn=lambda: self.update_training_buttons({
|
1133 |
+
"status": "training"
|
1134 |
+
}),
|
1135 |
+
outputs=[start_btn, stop_btn, pause_resume_btn]
|
1136 |
+
)
|
1137 |
+
|
1138 |
+
|
1139 |
+
pause_resume_btn.click(
|
1140 |
+
fn=self.handle_pause_resume,
|
1141 |
+
outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
|
1142 |
+
)
|
1143 |
+
|
1144 |
+
stop_btn.click(
|
1145 |
+
fn=self.trainer.stop_training,
|
1146 |
+
outputs=[status_box, log_box]
|
1147 |
+
).success(
|
1148 |
+
fn=self.handle_training_complete,
|
1149 |
+
outputs=[start_btn, stop_btn, pause_resume_btn]
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
def handle_global_stop():
|
1153 |
+
result = self.stop_all_and_clear()
|
1154 |
+
# Update all relevant UI components
|
1155 |
+
status = result["status"]
|
1156 |
+
details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
|
1157 |
+
full_status = f"{status}\n\nDetails:\n{details}"
|
1158 |
+
|
1159 |
+
# Get fresh lists after cleanup
|
1160 |
+
videos = self.splitter.list_unprocessed_videos()
|
1161 |
+
clips = self.list_training_files_to_caption()
|
1162 |
+
|
1163 |
+
return {
|
1164 |
+
global_status: gr.update(value=full_status, visible=True),
|
1165 |
+
video_list: videos,
|
1166 |
+
training_dataset: clips,
|
1167 |
+
status_box: "Training stopped and data cleared",
|
1168 |
+
log_box: "",
|
1169 |
+
detect_status: "Scene detection stopped",
|
1170 |
+
import_status: "All data cleared",
|
1171 |
+
preview_status: "Captioning stopped"
|
1172 |
+
}
|
1173 |
+
|
1174 |
+
download_dataset_btn.click(
|
1175 |
+
fn=self.trainer.create_training_dataset_zip,
|
1176 |
+
outputs=[download_dataset_btn]
|
1177 |
+
)
|
1178 |
+
|
1179 |
+
download_model_btn.click(
|
1180 |
+
fn=self.trainer.get_model_output_safetensors,
|
1181 |
+
outputs=[download_model_btn]
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
global_stop_btn.click(
|
1185 |
+
fn=handle_global_stop,
|
1186 |
+
outputs=[
|
1187 |
+
global_status,
|
1188 |
+
video_list,
|
1189 |
+
training_dataset,
|
1190 |
+
status_box,
|
1191 |
+
log_box,
|
1192 |
+
detect_status,
|
1193 |
+
import_status,
|
1194 |
+
preview_status
|
1195 |
+
]
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
# Auto-refresh timers
|
1199 |
+
app.load(
|
1200 |
+
fn=lambda: (
|
1201 |
+
self.refresh_dataset()
|
1202 |
+
),
|
1203 |
+
outputs=[
|
1204 |
+
video_list, training_dataset
|
1205 |
+
]
|
1206 |
+
)
|
1207 |
+
|
1208 |
+
timer = gr.Timer(value=1)
|
1209 |
+
timer.tick(
|
1210 |
+
fn=lambda: (
|
1211 |
+
self.refresh_training_status_and_logs()
|
1212 |
+
),
|
1213 |
+
outputs=[
|
1214 |
+
status_box,
|
1215 |
+
log_box
|
1216 |
+
]
|
1217 |
+
)
|
1218 |
+
|
1219 |
+
timer = gr.Timer(value=5)
|
1220 |
+
timer.tick(
|
1221 |
+
fn=lambda: (
|
1222 |
+
self.refresh_dataset()
|
1223 |
+
),
|
1224 |
+
outputs=[
|
1225 |
+
video_list, training_dataset
|
1226 |
+
]
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
timer = gr.Timer(value=5)
|
1230 |
+
timer.tick(
|
1231 |
+
fn=lambda: self.update_titles(),
|
1232 |
+
outputs=[
|
1233 |
+
split_title, caption_title, train_title
|
1234 |
+
]
|
1235 |
+
)
|
1236 |
+
|
1237 |
+
return app
|
1238 |
+
|
1239 |
+
def create_app():
|
1240 |
+
if ASK_USER_TO_DUPLICATE_SPACE:
|
1241 |
+
with gr.Blocks() as app:
|
1242 |
+
gr.Markdown("""# Finetrainers UI
|
1243 |
+
|
1244 |
+
This Hugging Face space needs to be duplicated to your own billing account to work.
|
1245 |
+
|
1246 |
+
Click the 'Duplicate Space' button at the top of the page to create your own copy.
|
1247 |
+
|
1248 |
+
It is recommended to use a Nvidia L40S and a persistent storage space.
|
1249 |
+
To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""")
|
1250 |
+
return app
|
1251 |
+
|
1252 |
+
ui = VideoTrainerUI()
|
1253 |
+
return ui.create_ui()
|
1254 |
+
|
1255 |
+
if __name__ == "__main__":
|
1256 |
+
app = create_app()
|
1257 |
+
|
1258 |
+
allowed_paths = [
|
1259 |
+
str(STORAGE_PATH), # Base storage
|
1260 |
+
str(VIDEOS_TO_SPLIT_PATH),
|
1261 |
+
str(STAGING_PATH),
|
1262 |
+
str(TRAINING_PATH),
|
1263 |
+
str(TRAINING_VIDEOS_PATH),
|
1264 |
+
str(MODEL_PATH),
|
1265 |
+
str(OUTPUT_PATH)
|
1266 |
+
]
|
1267 |
+
app.queue(default_concurrency_limit=1).launch(
|
1268 |
+
server_name="0.0.0.0",
|
1269 |
+
allowed_paths=allowed_paths
|
1270 |
+
)
|
captioning_service.py
ADDED
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import torch
|
3 |
+
import shutil
|
4 |
+
import gradio as gr
|
5 |
+
from llava.model.builder import load_pretrained_model
|
6 |
+
from llava.mm_utils import tokenizer_image_token
|
7 |
+
import numpy as np
|
8 |
+
from decord import VideoReader, cpu
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any, Tuple, Dict, Optional, AsyncGenerator, List
|
11 |
+
import asyncio
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from datetime import datetime
|
14 |
+
import cv2
|
15 |
+
from config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX
|
16 |
+
from utils import extract_scene_info, is_image_file, is_video_file
|
17 |
+
from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class CaptioningProgress:
|
23 |
+
video_name: str
|
24 |
+
total_frames: int
|
25 |
+
processed_frames: int
|
26 |
+
status: str
|
27 |
+
started_at: datetime
|
28 |
+
completed_at: Optional[datetime] = None
|
29 |
+
error: Optional[str] = None
|
30 |
+
|
31 |
+
class CaptioningService:
|
32 |
+
_instance = None
|
33 |
+
_model = None
|
34 |
+
_tokenizer = None
|
35 |
+
_image_processor = None
|
36 |
+
_model_loading = None
|
37 |
+
_loop = None
|
38 |
+
|
39 |
+
def __new__(cls, model_name=CAPTIONING_MODEL):
|
40 |
+
if cls._instance is not None:
|
41 |
+
return cls._instance
|
42 |
+
|
43 |
+
instance = super().__new__(cls)
|
44 |
+
if PRELOAD_CAPTIONING_MODEL:
|
45 |
+
cls._instance = instance
|
46 |
+
try:
|
47 |
+
cls._loop = asyncio.get_running_loop()
|
48 |
+
except RuntimeError:
|
49 |
+
cls._loop = asyncio.new_event_loop()
|
50 |
+
asyncio.set_event_loop(cls._loop)
|
51 |
+
|
52 |
+
if not USE_MOCK_CAPTIONING_MODEL and cls._model_loading is None:
|
53 |
+
cls._model_loading = cls._loop.create_task(cls._background_load_model(model_name))
|
54 |
+
return instance
|
55 |
+
|
56 |
+
def __init__(self, model_name=CAPTIONING_MODEL):
|
57 |
+
if hasattr(self, 'model_name'): # Already initialized
|
58 |
+
return
|
59 |
+
|
60 |
+
self.model_name = model_name
|
61 |
+
self.tokenizer = None
|
62 |
+
self.model = None
|
63 |
+
self.image_processor = None
|
64 |
+
self.active_tasks: Dict[str, CaptioningProgress] = {}
|
65 |
+
self._should_stop = False
|
66 |
+
self._model_loaded = False
|
67 |
+
|
68 |
+
@classmethod
|
69 |
+
async def _background_load_model(cls, model_name):
|
70 |
+
"""Background task to load the model"""
|
71 |
+
try:
|
72 |
+
logger.info("Starting background model loading...")
|
73 |
+
if not cls._loop:
|
74 |
+
cls._loop = asyncio.get_running_loop()
|
75 |
+
|
76 |
+
def load_model():
|
77 |
+
try:
|
78 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(
|
79 |
+
model_name, None, "llava_qwen",
|
80 |
+
torch_dtype="bfloat16", device_map="auto"
|
81 |
+
)
|
82 |
+
model.eval()
|
83 |
+
return tokenizer, model, image_processor
|
84 |
+
except Exception as e:
|
85 |
+
logger.error(f"Error in load_model: {str(e)}")
|
86 |
+
raise
|
87 |
+
|
88 |
+
result = await cls._loop.run_in_executor(None, load_model)
|
89 |
+
|
90 |
+
cls._tokenizer, cls._model, cls._image_processor = result
|
91 |
+
logger.info("Background model loading completed successfully!")
|
92 |
+
|
93 |
+
except Exception as e:
|
94 |
+
logger.error(f"Background model loading failed: {str(e)}")
|
95 |
+
cls._model_loading = None
|
96 |
+
raise
|
97 |
+
|
98 |
+
async def ensure_model_loaded(self):
|
99 |
+
"""Ensure model is loaded before processing"""
|
100 |
+
if USE_MOCK_CAPTIONING_MODEL:
|
101 |
+
logger.info("Using mock model, skipping model loading")
|
102 |
+
self.__class__._model_loading = None
|
103 |
+
self._model_loaded = True
|
104 |
+
return
|
105 |
+
|
106 |
+
if not self._model_loaded:
|
107 |
+
try:
|
108 |
+
if PRELOAD_CAPTIONING_MODEL and self.__class__._model_loading:
|
109 |
+
logger.info("Waiting for background model loading to complete...")
|
110 |
+
if self.__class__._loop and self.__class__._loop != asyncio.get_running_loop():
|
111 |
+
logger.warning("Different event loop detected, creating new loading task")
|
112 |
+
self.__class__._model_loading = None
|
113 |
+
await self._load_model_sync()
|
114 |
+
else:
|
115 |
+
await self.__class__._model_loading
|
116 |
+
self.model = self.__class__._model
|
117 |
+
self.tokenizer = self.__class__._tokenizer
|
118 |
+
self.image_processor = self.__class__._image_processor
|
119 |
+
else:
|
120 |
+
await self._load_model_sync()
|
121 |
+
|
122 |
+
self._model_loaded = True
|
123 |
+
logger.info("Model loading completed!")
|
124 |
+
except Exception as e:
|
125 |
+
logger.error(f"Error loading model: {str(e)}")
|
126 |
+
raise
|
127 |
+
|
128 |
+
async def _load_model_sync(self):
|
129 |
+
"""Synchronously load the model"""
|
130 |
+
logger.info("Loading model synchronously...")
|
131 |
+
current_loop = asyncio.get_running_loop()
|
132 |
+
|
133 |
+
def load_model():
|
134 |
+
return load_pretrained_model(
|
135 |
+
self.model_name, None, "llava_qwen",
|
136 |
+
torch_dtype="bfloat16", device_map="auto"
|
137 |
+
)
|
138 |
+
|
139 |
+
self.tokenizer, self.model, self.image_processor, _ = await current_loop.run_in_executor(
|
140 |
+
None, load_model
|
141 |
+
)
|
142 |
+
self.model.eval()
|
143 |
+
|
144 |
+
def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> tuple[np.ndarray, str, float]:
|
145 |
+
"""Load and preprocess video frames"""
|
146 |
+
|
147 |
+
video_path_str = str(video_path) if hasattr(video_path, '__fspath__') else video_path
|
148 |
+
|
149 |
+
logger.debug(f"Loading video: {video_path_str}")
|
150 |
+
|
151 |
+
if max_frames_num == 0:
|
152 |
+
return np.zeros((1, 336, 336, 3)), "", 0
|
153 |
+
|
154 |
+
vr = VideoReader(video_path_str, ctx=cpu(0), num_threads=1)
|
155 |
+
total_frame_num = len(vr)
|
156 |
+
video_time = total_frame_num / vr.get_avg_fps()
|
157 |
+
|
158 |
+
# Calculate frame indices
|
159 |
+
fps = round(vr.get_avg_fps()/fps)
|
160 |
+
frame_idx = [i for i in range(0, len(vr), fps)]
|
161 |
+
frame_time = [i/fps for i in frame_idx]
|
162 |
+
|
163 |
+
if len(frame_idx) > max_frames_num or force_sample:
|
164 |
+
sample_fps = max_frames_num
|
165 |
+
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
|
166 |
+
frame_idx = uniform_sampled_frames.tolist()
|
167 |
+
frame_time = [i/vr.get_avg_fps() for i in frame_idx]
|
168 |
+
|
169 |
+
frame_time_str = ",".join([f"{i:.2f}s" for i in frame_time])
|
170 |
+
|
171 |
+
try:
|
172 |
+
frames = vr.get_batch(frame_idx).asnumpy()
|
173 |
+
logger.debug(f"Loaded {len(frames)} frames with shape {frames.shape}")
|
174 |
+
return frames, frame_time_str, video_time
|
175 |
+
except Exception as e:
|
176 |
+
logger.error(f"Error loading video frames: {str(e)}")
|
177 |
+
raise
|
178 |
+
|
179 |
+
async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
|
180 |
+
try:
|
181 |
+
video_name = video_path.name
|
182 |
+
logger.info(f"Starting processing of video: {video_name}")
|
183 |
+
|
184 |
+
# Load video metadata
|
185 |
+
logger.debug(f"Loading video metadata for {video_name}")
|
186 |
+
loop = asyncio.get_event_loop()
|
187 |
+
vr = await loop.run_in_executor(None, lambda: VideoReader(str(video_path), ctx=cpu(0)))
|
188 |
+
total_frames = len(vr)
|
189 |
+
|
190 |
+
progress = CaptioningProgress(
|
191 |
+
video_name=video_name,
|
192 |
+
total_frames=total_frames,
|
193 |
+
processed_frames=0,
|
194 |
+
status="initializing",
|
195 |
+
started_at=datetime.now()
|
196 |
+
)
|
197 |
+
self.active_tasks[video_name] = progress
|
198 |
+
yield progress, None
|
199 |
+
|
200 |
+
# Get parent caption if this is a clip
|
201 |
+
parent_caption = ""
|
202 |
+
if "___" in video_path.stem:
|
203 |
+
parent_name, _ = extract_scene_info(video_path.stem)
|
204 |
+
#print(f"parent_name is {parent_name}")
|
205 |
+
parent_txt_path = VIDEOS_TO_SPLIT_PATH / f"{parent_name}.txt"
|
206 |
+
if parent_txt_path.exists():
|
207 |
+
logger.debug(f"Found parent caption file: {parent_txt_path}")
|
208 |
+
parent_caption = parent_txt_path.read_text().strip()
|
209 |
+
|
210 |
+
# Ensure model is loaded before processing
|
211 |
+
await self.ensure_model_loaded()
|
212 |
+
|
213 |
+
if USE_MOCK_CAPTIONING_MODEL:
|
214 |
+
|
215 |
+
# Even in mock mode, we'll generate a caption that shows we processed parent info
|
216 |
+
clip_caption = f"This is a test caption for {video_name}"
|
217 |
+
|
218 |
+
# Combine clip caption with parent caption
|
219 |
+
if parent_caption and not full_caption.endswith(parent_caption):
|
220 |
+
#print(f"we have parent_caption, so we define the full_caption as {clip_caption}\n{parent_caption}")
|
221 |
+
|
222 |
+
full_caption = f"{clip_caption}\n{parent_caption}"
|
223 |
+
else:
|
224 |
+
#print(f"we don't have a parent_caption, so we define the full_caption as {clip_caption}")
|
225 |
+
|
226 |
+
full_caption = clip_caption
|
227 |
+
|
228 |
+
if prompt_prefix and not full_caption.startswith(prompt_prefix):
|
229 |
+
full_caption = f"{prompt_prefix}{full_caption}"
|
230 |
+
|
231 |
+
# Write the caption file
|
232 |
+
txt_path = video_path.with_suffix('.txt')
|
233 |
+
txt_path.write_text(full_caption)
|
234 |
+
|
235 |
+
logger.debug(f"Mock mode: Saved caption to {txt_path}")
|
236 |
+
|
237 |
+
progress.status = "completed"
|
238 |
+
progress.processed_frames = total_frames
|
239 |
+
progress.completed_at = datetime.now()
|
240 |
+
yield progress, full_caption
|
241 |
+
|
242 |
+
else:
|
243 |
+
# Process frames in batches
|
244 |
+
max_frames_num = 64
|
245 |
+
frames, frame_times_str, video_time = await loop.run_in_executor(
|
246 |
+
None,
|
247 |
+
lambda: self._load_video(video_path, max_frames_num)
|
248 |
+
)
|
249 |
+
|
250 |
+
# Process all frames at once using the image processor
|
251 |
+
processed_frames = await loop.run_in_executor(
|
252 |
+
None,
|
253 |
+
lambda: self.image_processor.preprocess(
|
254 |
+
frames,
|
255 |
+
return_tensors="pt"
|
256 |
+
)["pixel_values"]
|
257 |
+
)
|
258 |
+
|
259 |
+
# Update progress
|
260 |
+
progress.processed_frames = len(frames)
|
261 |
+
progress.status = "generating caption"
|
262 |
+
yield progress, None
|
263 |
+
|
264 |
+
# Move processed frames to GPU
|
265 |
+
video_tensor = processed_frames.to('cuda').bfloat16()
|
266 |
+
|
267 |
+
time_instruction = (f"The video lasts for {video_time:.2f} seconds, and {len(frames)} "
|
268 |
+
f"frames are uniformly sampled from it. These frames are located at {frame_times_str}.")
|
269 |
+
full_prompt = f"<image>{time_instruction}\n{prompt}"
|
270 |
+
|
271 |
+
input_ids = await loop.run_in_executor(
|
272 |
+
None,
|
273 |
+
lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda')
|
274 |
+
)
|
275 |
+
|
276 |
+
# Generate caption
|
277 |
+
with torch.no_grad():
|
278 |
+
output = await loop.run_in_executor(
|
279 |
+
None,
|
280 |
+
lambda: self.model.generate(
|
281 |
+
input_ids,
|
282 |
+
images=[video_tensor],
|
283 |
+
modalities=["video"],
|
284 |
+
do_sample=False,
|
285 |
+
temperature=0,
|
286 |
+
max_new_tokens=4096,
|
287 |
+
)
|
288 |
+
)
|
289 |
+
|
290 |
+
clip_caption = await loop.run_in_executor(
|
291 |
+
None,
|
292 |
+
lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
|
293 |
+
)
|
294 |
+
|
295 |
+
# Combine clip caption with parent caption
|
296 |
+
if parent_caption:
|
297 |
+
print(f"we have parent_caption, so we define the full_caption as {clip_caption}\n{parent_caption}")
|
298 |
+
|
299 |
+
full_caption = f"{clip_caption}\n{parent_caption}"
|
300 |
+
else:
|
301 |
+
print(f"we don't have a parent_caption, so we define the full_caption as {clip_caption}")
|
302 |
+
|
303 |
+
full_caption = clip_caption
|
304 |
+
|
305 |
+
if prompt_prefix:
|
306 |
+
full_caption = f"{prompt_prefix}{full_caption}"
|
307 |
+
|
308 |
+
|
309 |
+
# Write the caption file
|
310 |
+
txt_path = video_path.with_suffix('.txt')
|
311 |
+
txt_path.write_text(full_caption)
|
312 |
+
|
313 |
+
progress.status = "completed"
|
314 |
+
progress.completed_at = datetime.now()
|
315 |
+
gr.Info(f"Successfully generated caption for {video_name}")
|
316 |
+
yield progress, full_caption
|
317 |
+
|
318 |
+
except Exception as e:
|
319 |
+
progress.status = "error"
|
320 |
+
progress.error = str(e)
|
321 |
+
progress.completed_at = datetime.now()
|
322 |
+
yield progress, None
|
323 |
+
raise gr.Error(f"Error processing video: {str(e)}")
|
324 |
+
|
325 |
+
async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
|
326 |
+
"""Process a single image for captioning"""
|
327 |
+
try:
|
328 |
+
image_name = image_path.name
|
329 |
+
logger.info(f"Starting processing of image: {image_name}")
|
330 |
+
|
331 |
+
progress = CaptioningProgress(
|
332 |
+
video_name=image_name, # Reusing video_name field for images
|
333 |
+
total_frames=1,
|
334 |
+
processed_frames=0,
|
335 |
+
status="initializing",
|
336 |
+
started_at=datetime.now()
|
337 |
+
)
|
338 |
+
self.active_tasks[image_name] = progress
|
339 |
+
yield progress, None
|
340 |
+
|
341 |
+
# Ensure model is loaded
|
342 |
+
await self.ensure_model_loaded()
|
343 |
+
|
344 |
+
if USE_MOCK_CAPTIONING_MODEL:
|
345 |
+
progress.status = "completed"
|
346 |
+
progress.processed_frames = 1
|
347 |
+
progress.completed_at = datetime.now()
|
348 |
+
print("yielding fake")
|
349 |
+
yield progress, "This is a test image caption"
|
350 |
+
return
|
351 |
+
|
352 |
+
# Read and process image
|
353 |
+
loop = asyncio.get_event_loop()
|
354 |
+
image = await loop.run_in_executor(
|
355 |
+
None,
|
356 |
+
lambda: cv2.imread(str(image_path))
|
357 |
+
)
|
358 |
+
if image is None:
|
359 |
+
raise ValueError(f"Could not read image: {str(image_path)}")
|
360 |
+
|
361 |
+
# Convert BGR to RGB
|
362 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
363 |
+
|
364 |
+
# Process image
|
365 |
+
processed_image = await loop.run_in_executor(
|
366 |
+
None,
|
367 |
+
lambda: self.image_processor.preprocess(
|
368 |
+
image,
|
369 |
+
return_tensors="pt"
|
370 |
+
)["pixel_values"]
|
371 |
+
)
|
372 |
+
|
373 |
+
progress.processed_frames = 1
|
374 |
+
progress.status = "generating caption"
|
375 |
+
yield progress, None
|
376 |
+
|
377 |
+
# Move to GPU and generate caption
|
378 |
+
image_tensor = processed_image.to('cuda').bfloat16()
|
379 |
+
full_prompt = f"<image>{prompt}"
|
380 |
+
|
381 |
+
input_ids = await loop.run_in_executor(
|
382 |
+
None,
|
383 |
+
lambda: tokenizer_image_token(full_prompt, self.tokenizer, return_tensors="pt").unsqueeze(0).to('cuda')
|
384 |
+
)
|
385 |
+
|
386 |
+
with torch.no_grad():
|
387 |
+
output = await loop.run_in_executor(
|
388 |
+
None,
|
389 |
+
lambda: self.model.generate(
|
390 |
+
input_ids,
|
391 |
+
images=[image_tensor],
|
392 |
+
modalities=["image"],
|
393 |
+
do_sample=False,
|
394 |
+
temperature=0,
|
395 |
+
max_new_tokens=4096,
|
396 |
+
)
|
397 |
+
)
|
398 |
+
|
399 |
+
caption = await loop.run_in_executor(
|
400 |
+
None,
|
401 |
+
lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
|
402 |
+
)
|
403 |
+
|
404 |
+
progress.status = "completed"
|
405 |
+
progress.completed_at = datetime.now()
|
406 |
+
gr.Info(f"Successfully generated caption for {image_name}")
|
407 |
+
yield progress, caption
|
408 |
+
|
409 |
+
except Exception as e:
|
410 |
+
progress.status = "error"
|
411 |
+
progress.error = str(e)
|
412 |
+
progress.completed_at = datetime.now()
|
413 |
+
yield progress, None
|
414 |
+
raise gr.Error(f"Error processing image: {str(e)}")
|
415 |
+
|
416 |
+
|
417 |
+
async def start_caption_generation(self, custom_prompt: str, prompt_prefix: str) -> AsyncGenerator[List[List[str]], None]:
|
418 |
+
"""Iterates over clips to auto-generate captions asynchronously."""
|
419 |
+
try:
|
420 |
+
logger.info("Starting auto-caption generation")
|
421 |
+
|
422 |
+
# Use provided prompt or default
|
423 |
+
default_prompt = DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
|
424 |
+
prompt = custom_prompt.strip() or default_prompt
|
425 |
+
logger.debug(f"Using prompt: {prompt}")
|
426 |
+
|
427 |
+
# Find files needing captions
|
428 |
+
video_files = list(STAGING_PATH.glob("*.mp4"))
|
429 |
+
image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
|
430 |
+
all_files = video_files + image_files
|
431 |
+
|
432 |
+
# Filter for files missing captions or with empty caption files
|
433 |
+
files_to_process = []
|
434 |
+
for file_path in all_files:
|
435 |
+
caption_path = file_path.with_suffix('.txt')
|
436 |
+
needs_caption = (
|
437 |
+
not caption_path.exists() or
|
438 |
+
caption_path.stat().st_size == 0 or
|
439 |
+
caption_path.read_text().strip() == ""
|
440 |
+
)
|
441 |
+
if needs_caption:
|
442 |
+
files_to_process.append(file_path)
|
443 |
+
|
444 |
+
logger.info(f"Found {len(files_to_process)} files needing captions")
|
445 |
+
|
446 |
+
if not files_to_process:
|
447 |
+
logger.info("No files need captioning")
|
448 |
+
yield []
|
449 |
+
return
|
450 |
+
|
451 |
+
self._should_stop = False
|
452 |
+
self.active_tasks.clear()
|
453 |
+
status_update: Dict[str, Dict[str, Any]] = {}
|
454 |
+
|
455 |
+
for file_path in all_files:
|
456 |
+
if self._should_stop:
|
457 |
+
break
|
458 |
+
|
459 |
+
try:
|
460 |
+
print(f"we are in file_path {str(file_path)}")
|
461 |
+
# Choose appropriate processing method based on file type
|
462 |
+
if is_video_file(file_path):
|
463 |
+
process_gen = self.process_video(file_path, prompt, prompt_prefix)
|
464 |
+
else:
|
465 |
+
process_gen = self.process_image(file_path, prompt, prompt_prefix)
|
466 |
+
print("got process_gen = ", process_gen)
|
467 |
+
async for progress, caption in process_gen:
|
468 |
+
print(f"process_gen contains this caption = {caption}")
|
469 |
+
if caption and prompt_prefix and not caption.startswith(prompt_prefix):
|
470 |
+
caption = f"{prompt_prefix}{caption}"
|
471 |
+
|
472 |
+
# Save caption
|
473 |
+
if caption:
|
474 |
+
txt_path = file_path.with_suffix('.txt')
|
475 |
+
txt_path.write_text(caption)
|
476 |
+
|
477 |
+
logger.debug(f"Progress update: {progress.status}")
|
478 |
+
|
479 |
+
# Store progress info
|
480 |
+
status_update[file_path.name] = {
|
481 |
+
"status": progress.status,
|
482 |
+
"frames": progress.processed_frames,
|
483 |
+
"total": progress.total_frames
|
484 |
+
}
|
485 |
+
|
486 |
+
# Convert to list format for Gradio DataFrame
|
487 |
+
rows = []
|
488 |
+
for file_name, info in status_update.items():
|
489 |
+
status = info["status"]
|
490 |
+
if status == "processing":
|
491 |
+
percent = (info["frames"] / info["total"]) * 100
|
492 |
+
status = f"Analyzing... {percent:.1f}% ({info['frames']}/{info['total']} frames)"
|
493 |
+
elif status == "generating caption":
|
494 |
+
status = "Generating caption..."
|
495 |
+
elif status == "error":
|
496 |
+
status = f"Error: {progress.error}"
|
497 |
+
elif status == "completed":
|
498 |
+
status = "Completed"
|
499 |
+
|
500 |
+
rows.append([file_name, status])
|
501 |
+
|
502 |
+
yield rows
|
503 |
+
await asyncio.sleep(0.1)
|
504 |
+
|
505 |
+
|
506 |
+
except Exception as e:
|
507 |
+
logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
|
508 |
+
rows = [[str(file_path.name), f"Error: {str(e)}"]]
|
509 |
+
yield rows
|
510 |
+
continue
|
511 |
+
|
512 |
+
logger.info("Auto-caption generation completed, cyping assets to the training dir..")
|
513 |
+
|
514 |
+
copy_files_to_training_dir(prompt_prefix)
|
515 |
+
except Exception as e:
|
516 |
+
logger.error(f"Error in start_caption_generation: {str(e)}")
|
517 |
+
yield [[str(e), "error"]]
|
518 |
+
raise
|
519 |
+
|
520 |
+
def stop_captioning(self):
|
521 |
+
"""Stop all ongoing captioning tasks"""
|
522 |
+
logger.info("Stopping all captioning tasks")
|
523 |
+
self._should_stop = True
|
524 |
+
|
525 |
+
def close(self):
|
526 |
+
"""Clean up resources"""
|
527 |
+
logger.info("Cleaning up captioning service resources")
|
528 |
+
if hasattr(self, 'model'):
|
529 |
+
del self.model
|
530 |
+
if hasattr(self, 'tokenizer'):
|
531 |
+
del self.tokenizer
|
532 |
+
if hasattr(self, 'image_processor'):
|
533 |
+
del self.image_processor
|
534 |
+
torch.cuda.empty_cache()
|
config.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from typing import Dict, Any, Optional, List, Tuple
|
4 |
+
from pathlib import Path
|
5 |
+
from utils import parse_bool_env
|
6 |
+
|
7 |
+
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
8 |
+
ASK_USER_TO_DUPLICATE_SPACE = parse_bool_env(os.getenv("ASK_USER_TO_DUPLICATE_SPACE"))
|
9 |
+
|
10 |
+
# Base storage path
|
11 |
+
STORAGE_PATH = Path(os.environ.get('STORAGE_PATH', '.data'))
|
12 |
+
|
13 |
+
# Subdirectories for different data types
|
14 |
+
VIDEOS_TO_SPLIT_PATH = STORAGE_PATH / "videos_to_split" # Raw uploaded/downloaded files
|
15 |
+
STAGING_PATH = STORAGE_PATH / "staging" # This is where files that are captioned or need captioning are waiting
|
16 |
+
TRAINING_PATH = STORAGE_PATH / "training" # Folder containing the final training dataset
|
17 |
+
TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos" # Captioned clips ready for training
|
18 |
+
MODEL_PATH = STORAGE_PATH / "model" # Model checkpoints and files
|
19 |
+
OUTPUT_PATH = STORAGE_PATH / "output" # Training outputs and logs
|
20 |
+
|
21 |
+
# On the production server we can afford to preload the big model
|
22 |
+
PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))
|
23 |
+
|
24 |
+
CAPTIONING_MODEL = "lmms-lab/LLaVA-Video-7B-Qwen2"
|
25 |
+
|
26 |
+
DEFAULT_PROMPT_PREFIX = "In the style of TOK, "
|
27 |
+
|
28 |
+
# This is only use to debug things in local
|
29 |
+
USE_MOCK_CAPTIONING_MODEL = parse_bool_env(os.environ.get('USE_MOCK_CAPTIONING_MODEL'))
|
30 |
+
|
31 |
+
DEFAULT_CAPTIONING_BOT_INSTRUCTIONS = "Please write a full description of the following video: camera (close-up shot, medium-shot..), genre (music video, horror movie scene, video game footage, go pro footage, japanese anime, noir film, science-fiction, action movie, documentary..), characters (physical appearance, look, skin, facial features, haircut, clothing), scene (action, positions, movements), location (indoor, outdoor, place, building, country..), time and lighting (natural, golden hour, night time, LED lights, kelvin temperature etc), weather and climate (dusty, rainy, fog, haze, snowing..), era/settings"
|
32 |
+
|
33 |
+
# Create directories
|
34 |
+
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
35 |
+
VIDEOS_TO_SPLIT_PATH.mkdir(parents=True, exist_ok=True)
|
36 |
+
STAGING_PATH.mkdir(parents=True, exist_ok=True)
|
37 |
+
TRAINING_PATH.mkdir(parents=True, exist_ok=True)
|
38 |
+
TRAINING_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)
|
39 |
+
MODEL_PATH.mkdir(parents=True, exist_ok=True)
|
40 |
+
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
|
41 |
+
|
42 |
+
# Image normalization settings
|
43 |
+
NORMALIZE_IMAGES_TO = os.environ.get('NORMALIZE_IMAGES_TO', 'png').lower()
|
44 |
+
if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
|
45 |
+
raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
|
46 |
+
JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
|
47 |
+
|
48 |
+
MODEL_TYPES = {
|
49 |
+
"HunyuanVideo (LoRA)": "hunyuan_video",
|
50 |
+
"LTX-Video (LoRA)": "ltx_video"
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
# it is best to use resolutions that are powers of 8
|
55 |
+
# The resolution should be divisible by 32
|
56 |
+
# so we cannot use 1080, 540 etc as they are not divisible by 32
|
57 |
+
TRAINING_WIDTH = 768 # 32 * 24
|
58 |
+
TRAINING_HEIGHT = 512 # 32 * 16
|
59 |
+
|
60 |
+
# 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
|
61 |
+
# 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
|
62 |
+
# 1056 = 32 * 33 (divided by 2: 544 = 17 * 32)
|
63 |
+
# 1024 = 32 * 32 (divided by 2: 512 = 16 * 32)
|
64 |
+
# it is important that the resolution buckets properly cover the training dataset,
|
65 |
+
# or else that we exclude from the dataset videos that are out of this range
|
66 |
+
# right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
|
67 |
+
|
68 |
+
TRAINING_BUCKETS = [
|
69 |
+
(8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 16 + 1
|
70 |
+
(8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 32 + 1
|
71 |
+
(8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 48 + 1
|
72 |
+
(8 * 8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 64 + 1
|
73 |
+
(8 * 10 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 80 + 1
|
74 |
+
(8 * 12 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 96 + 1
|
75 |
+
(8 * 14 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 112 + 1
|
76 |
+
(8 * 16 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 128 + 1
|
77 |
+
(8 * 18 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 144 + 1
|
78 |
+
(8 * 20 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 160 + 1
|
79 |
+
(8 * 22 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 176 + 1
|
80 |
+
(8 * 24 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 192 + 1
|
81 |
+
(8 * 28 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 224 + 1
|
82 |
+
(8 * 32 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 256 + 1
|
83 |
+
]
|
84 |
+
|
85 |
+
@dataclass
|
86 |
+
class TrainingConfig:
|
87 |
+
"""Configuration class for finetrainers training"""
|
88 |
+
|
89 |
+
# Required arguments must come first
|
90 |
+
model_name: str
|
91 |
+
pretrained_model_name_or_path: str
|
92 |
+
data_root: str
|
93 |
+
output_dir: str
|
94 |
+
|
95 |
+
# Optional arguments follow
|
96 |
+
revision: Optional[str] = None
|
97 |
+
variant: Optional[str] = None
|
98 |
+
cache_dir: Optional[str] = None
|
99 |
+
|
100 |
+
# Dataset arguments
|
101 |
+
|
102 |
+
# note: video_column and caption_column serve a dual purpose,
|
103 |
+
# when using the CSV mode they have to be CSV column names,
|
104 |
+
# otherwise they have to be filename (relative to the data_root dir path)
|
105 |
+
video_column: str = "videos.txt"
|
106 |
+
caption_column: str = "prompts.txt"
|
107 |
+
|
108 |
+
id_token: Optional[str] = None
|
109 |
+
video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: TRAINING_BUCKETS)
|
110 |
+
video_reshape_mode: str = "center"
|
111 |
+
caption_dropout_p: float = 0.05
|
112 |
+
caption_dropout_technique: str = "empty"
|
113 |
+
precompute_conditions: bool = False
|
114 |
+
|
115 |
+
# Diffusion arguments
|
116 |
+
flow_resolution_shifting: bool = False
|
117 |
+
flow_weighting_scheme: str = "none"
|
118 |
+
flow_logit_mean: float = 0.0
|
119 |
+
flow_logit_std: float = 1.0
|
120 |
+
flow_mode_scale: float = 1.29
|
121 |
+
|
122 |
+
# Training arguments
|
123 |
+
training_type: str = "lora"
|
124 |
+
seed: int = 42
|
125 |
+
mixed_precision: str = "bf16"
|
126 |
+
batch_size: int = 1
|
127 |
+
train_epochs: int = 70
|
128 |
+
lora_rank: int = 128
|
129 |
+
lora_alpha: int = 128
|
130 |
+
target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
|
131 |
+
gradient_accumulation_steps: int = 1
|
132 |
+
gradient_checkpointing: bool = True
|
133 |
+
checkpointing_steps: int = 500
|
134 |
+
checkpointing_limit: Optional[int] = 2
|
135 |
+
resume_from_checkpoint: Optional[str] = None
|
136 |
+
enable_slicing: bool = True
|
137 |
+
enable_tiling: bool = True
|
138 |
+
|
139 |
+
# Optimizer arguments
|
140 |
+
optimizer: str = "adamw"
|
141 |
+
lr: float = 3e-5
|
142 |
+
scale_lr: bool = False
|
143 |
+
lr_scheduler: str = "constant_with_warmup"
|
144 |
+
lr_warmup_steps: int = 100
|
145 |
+
lr_num_cycles: int = 1
|
146 |
+
lr_power: float = 1.0
|
147 |
+
beta1: float = 0.9
|
148 |
+
beta2: float = 0.95
|
149 |
+
weight_decay: float = 1e-4
|
150 |
+
epsilon: float = 1e-8
|
151 |
+
max_grad_norm: float = 1.0
|
152 |
+
|
153 |
+
# Miscellaneous arguments
|
154 |
+
tracker_name: str = "finetrainers"
|
155 |
+
report_to: str = "wandb"
|
156 |
+
nccl_timeout: int = 1800
|
157 |
+
|
158 |
+
@classmethod
|
159 |
+
def hunyuan_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig':
|
160 |
+
"""Configuration for Hunyuan video-to-video LoRA training"""
|
161 |
+
return cls(
|
162 |
+
model_name="hunyuan_video",
|
163 |
+
pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo",
|
164 |
+
data_root=data_path,
|
165 |
+
output_dir=output_path,
|
166 |
+
batch_size=1,
|
167 |
+
train_epochs=70,
|
168 |
+
lr=2e-5,
|
169 |
+
gradient_checkpointing=True,
|
170 |
+
id_token="afkx",
|
171 |
+
gradient_accumulation_steps=1,
|
172 |
+
lora_rank=128,
|
173 |
+
lora_alpha=128,
|
174 |
+
video_resolution_buckets=TRAINING_BUCKETS,
|
175 |
+
caption_dropout_p=0.05,
|
176 |
+
flow_weighting_scheme="none" # Hunyuan specific
|
177 |
+
)
|
178 |
+
|
179 |
+
@classmethod
|
180 |
+
def ltx_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig':
|
181 |
+
"""Configuration for LTX-Video LoRA training"""
|
182 |
+
return cls(
|
183 |
+
model_name="ltx_video",
|
184 |
+
pretrained_model_name_or_path="Lightricks/LTX-Video",
|
185 |
+
data_root=data_path,
|
186 |
+
output_dir=output_path,
|
187 |
+
batch_size=1,
|
188 |
+
train_epochs=70,
|
189 |
+
lr=3e-5,
|
190 |
+
gradient_checkpointing=True,
|
191 |
+
id_token="BW_STYLE",
|
192 |
+
gradient_accumulation_steps=4,
|
193 |
+
lora_rank=128,
|
194 |
+
lora_alpha=128,
|
195 |
+
video_resolution_buckets=TRAINING_BUCKETS,
|
196 |
+
caption_dropout_p=0.05,
|
197 |
+
flow_weighting_scheme="logit_normal" # LTX specific
|
198 |
+
)
|
199 |
+
|
200 |
+
def to_args_list(self) -> List[str]:
|
201 |
+
"""Convert config to command line arguments list"""
|
202 |
+
args = []
|
203 |
+
|
204 |
+
# Model arguments
|
205 |
+
|
206 |
+
# Add model_name (required argument)
|
207 |
+
args.extend(["--model_name", self.model_name])
|
208 |
+
|
209 |
+
args.extend(["--pretrained_model_name_or_path", self.pretrained_model_name_or_path])
|
210 |
+
if self.revision:
|
211 |
+
args.extend(["--revision", self.revision])
|
212 |
+
if self.variant:
|
213 |
+
args.extend(["--variant", self.variant])
|
214 |
+
if self.cache_dir:
|
215 |
+
args.extend(["--cache_dir", self.cache_dir])
|
216 |
+
|
217 |
+
# Dataset arguments
|
218 |
+
args.extend(["--data_root", self.data_root])
|
219 |
+
args.extend(["--video_column", self.video_column])
|
220 |
+
args.extend(["--caption_column", self.caption_column])
|
221 |
+
if self.id_token:
|
222 |
+
args.extend(["--id_token", self.id_token])
|
223 |
+
|
224 |
+
# Add video resolution buckets
|
225 |
+
if self.video_resolution_buckets:
|
226 |
+
bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
|
227 |
+
args.extend(["--video_resolution_buckets"] + bucket_strs)
|
228 |
+
|
229 |
+
if self.video_reshape_mode:
|
230 |
+
args.extend(["--video_reshape_mode", self.video_reshape_mode])
|
231 |
+
|
232 |
+
args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
|
233 |
+
args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
|
234 |
+
if self.precompute_conditions:
|
235 |
+
args.append("--precompute_conditions")
|
236 |
+
|
237 |
+
# Diffusion arguments
|
238 |
+
if self.flow_resolution_shifting:
|
239 |
+
args.append("--flow_resolution_shifting")
|
240 |
+
args.extend(["--flow_weighting_scheme", self.flow_weighting_scheme])
|
241 |
+
args.extend(["--flow_logit_mean", str(self.flow_logit_mean)])
|
242 |
+
args.extend(["--flow_logit_std", str(self.flow_logit_std)])
|
243 |
+
args.extend(["--flow_mode_scale", str(self.flow_mode_scale)])
|
244 |
+
|
245 |
+
# Training arguments
|
246 |
+
args.extend(["--training_type", self.training_type])
|
247 |
+
args.extend(["--seed", str(self.seed)])
|
248 |
+
|
249 |
+
# we don't use this, because mixed precision is handled by accelerate launch, not by the training script itself.
|
250 |
+
#args.extend(["--mixed_precision", self.mixed_precision])
|
251 |
+
|
252 |
+
args.extend(["--batch_size", str(self.batch_size)])
|
253 |
+
args.extend(["--train_epochs", str(self.train_epochs)])
|
254 |
+
args.extend(["--rank", str(self.lora_rank)])
|
255 |
+
args.extend(["--lora_alpha", str(self.lora_alpha)])
|
256 |
+
args.extend(["--target_modules"] + self.target_modules)
|
257 |
+
args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
|
258 |
+
if self.gradient_checkpointing:
|
259 |
+
args.append("--gradient_checkpointing")
|
260 |
+
args.extend(["--checkpointing_steps", str(self.checkpointing_steps)])
|
261 |
+
if self.checkpointing_limit:
|
262 |
+
args.extend(["--checkpointing_limit", str(self.checkpointing_limit)])
|
263 |
+
if self.resume_from_checkpoint:
|
264 |
+
args.extend(["--resume_from_checkpoint", self.resume_from_checkpoint])
|
265 |
+
if self.enable_slicing:
|
266 |
+
args.append("--enable_slicing")
|
267 |
+
if self.enable_tiling:
|
268 |
+
args.append("--enable_tiling")
|
269 |
+
|
270 |
+
# Optimizer arguments
|
271 |
+
args.extend(["--optimizer", self.optimizer])
|
272 |
+
args.extend(["--lr", str(self.lr)])
|
273 |
+
if self.scale_lr:
|
274 |
+
args.append("--scale_lr")
|
275 |
+
args.extend(["--lr_scheduler", self.lr_scheduler])
|
276 |
+
args.extend(["--lr_warmup_steps", str(self.lr_warmup_steps)])
|
277 |
+
args.extend(["--lr_num_cycles", str(self.lr_num_cycles)])
|
278 |
+
args.extend(["--lr_power", str(self.lr_power)])
|
279 |
+
args.extend(["--beta1", str(self.beta1)])
|
280 |
+
args.extend(["--beta2", str(self.beta2)])
|
281 |
+
args.extend(["--weight_decay", str(self.weight_decay)])
|
282 |
+
args.extend(["--epsilon", str(self.epsilon)])
|
283 |
+
args.extend(["--max_grad_norm", str(self.max_grad_norm)])
|
284 |
+
|
285 |
+
# Miscellaneous arguments
|
286 |
+
args.extend(["--tracker_name", self.tracker_name])
|
287 |
+
args.extend(["--output_dir", self.output_dir])
|
288 |
+
args.extend(["--report_to", self.report_to])
|
289 |
+
args.extend(["--nccl_timeout", str(self.nccl_timeout)])
|
290 |
+
|
291 |
+
# normally this is disabled by default, but there was a bug in finetrainers
|
292 |
+
# so I had to fix it in trainer.py to make sure we check for push_to-hub
|
293 |
+
#args.append("--push_to_hub")
|
294 |
+
#args.extend(["--hub_token", str(False)])
|
295 |
+
#args.extend(["--hub_model_id", str(False)])
|
296 |
+
|
297 |
+
# If you are using LLM-captioned videos, it is common to see many unwanted starting phrases like
|
298 |
+
# "In this video, ...", "This video features ...", etc.
|
299 |
+
# To remove a simple subset of these phrases, you can specify
|
300 |
+
# --remove_common_llm_caption_prefixes when starting training.
|
301 |
+
args.append("--remove_common_llm_caption_prefixes")
|
302 |
+
|
303 |
+
return args
|
finetrainers/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .args import Args, parse_arguments
|
2 |
+
from .trainer import Trainer
|
finetrainers/args.py
ADDED
@@ -0,0 +1,1191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import sys
|
3 |
+
from typing import Any, Dict, List, Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS
|
8 |
+
from .models import SUPPORTED_MODEL_CONFIGS
|
9 |
+
|
10 |
+
|
11 |
+
class Args:
|
12 |
+
r"""
|
13 |
+
The arguments for the finetrainers training script.
|
14 |
+
|
15 |
+
For helpful information about arguments, run `python train.py --help`.
|
16 |
+
|
17 |
+
TODO(aryan): add `python train.py --recommend_configs --model_name <model_name>` to recommend
|
18 |
+
good training configs for a model after extensive testing.
|
19 |
+
TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
|
20 |
+
memory requirements per model, per training type with sensible training settings.
|
21 |
+
|
22 |
+
MODEL ARGUMENTS
|
23 |
+
---------------
|
24 |
+
model_name (`str`):
|
25 |
+
Name of model to train. To get a list of models, run `python train.py --list_models`.
|
26 |
+
pretrained_model_name_or_path (`str`):
|
27 |
+
Path to pretrained model or model identifier from https://huggingface.co/models. The model should be
|
28 |
+
loadable based on specified `model_name`.
|
29 |
+
revision (`str`, defaults to `None`):
|
30 |
+
If provided, the model will be loaded from a specific branch of the model repository.
|
31 |
+
variant (`str`, defaults to `None`):
|
32 |
+
Variant of model weights to use. Some models provide weight variants, such as `fp16`, to reduce disk
|
33 |
+
storage requirements.
|
34 |
+
cache_dir (`str`, defaults to `None`):
|
35 |
+
The directory where the downloaded models and datasets will be stored, or loaded from.
|
36 |
+
text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
|
37 |
+
Data type for the text encoder when generating text embeddings.
|
38 |
+
text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
|
39 |
+
Data type for the text encoder 2 when generating text embeddings.
|
40 |
+
text_encoder_3_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
|
41 |
+
Data type for the text encoder 3 when generating text embeddings.
|
42 |
+
transformer_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
|
43 |
+
Data type for the transformer model.
|
44 |
+
vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
|
45 |
+
Data type for the VAE model.
|
46 |
+
layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
|
47 |
+
Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
|
48 |
+
layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
|
49 |
+
Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
|
50 |
+
layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
|
51 |
+
Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when casted to fp8 precision
|
52 |
+
naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
|
53 |
+
by default, and recommend adding more layers to the default list based on the model architecture.
|
54 |
+
|
55 |
+
DATASET ARGUMENTS
|
56 |
+
-----------------
|
57 |
+
data_root (`str`):
|
58 |
+
A folder containing the training data.
|
59 |
+
dataset_file (`str`, defaults to `None`):
|
60 |
+
Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not using
|
61 |
+
a directory dataset format containing a simple `prompts.txt` and `videos.txt`/`images.txt` for example.
|
62 |
+
video_column (`str`):
|
63 |
+
The column of the dataset containing videos. Or, the name of the file in `data_root` folder containing the
|
64 |
+
line-separated path to video data.
|
65 |
+
caption_column (`str`):
|
66 |
+
The column of the dataset containing the instance prompt for each video. Or, the name of the file in
|
67 |
+
`data_root` folder containing the line-separated instance prompts.
|
68 |
+
id_token (`str`, defaults to `None`):
|
69 |
+
Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training.
|
70 |
+
image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`):
|
71 |
+
Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the
|
72 |
+
resolution (height, width) of the image. All images will be resized to the nearest bucket resolution.
|
73 |
+
video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`):
|
74 |
+
Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the
|
75 |
+
resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket
|
76 |
+
resolution.
|
77 |
+
video_reshape_mode (`str`, defaults to `None`):
|
78 |
+
All input videos are reshaped to this mode. Choose between ['center', 'random', 'none'].
|
79 |
+
TODO(aryan): We don't support this.
|
80 |
+
caption_dropout_p (`float`, defaults to `0.00`):
|
81 |
+
Probability of dropout for the caption tokens. This is useful to improve the unconditional generation
|
82 |
+
quality of the model.
|
83 |
+
caption_dropout_technique (`str`, defaults to `empty`):
|
84 |
+
Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption dropout
|
85 |
+
by setting the prompt condition to an empty string, while others zero-out the text embedding tensors.
|
86 |
+
precompute_conditions (`bool`, defaults to `False`):
|
87 |
+
Whether or not to precompute the conditionings for the model. This is useful for faster training, and
|
88 |
+
reduces the memory requirements.
|
89 |
+
remove_common_llm_caption_prefixes (`bool`, defaults to `False`):
|
90 |
+
Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the
|
91 |
+
generated text.
|
92 |
+
|
93 |
+
DATALOADER_ARGUMENTS
|
94 |
+
--------------------
|
95 |
+
See https://pytorch.org/docs/stable/data.html for more information.
|
96 |
+
|
97 |
+
dataloader_num_workers (`int`, defaults to `0`):
|
98 |
+
Number of subprocesses to use for data loading. `0` means that the data will be loaded in a blocking manner
|
99 |
+
on the main process.
|
100 |
+
pin_memory (`bool`, defaults to `False`):
|
101 |
+
Whether or not to use the pinned memory setting in PyTorch dataloader. This is useful for faster data loading.
|
102 |
+
|
103 |
+
DIFFUSION ARGUMENTS
|
104 |
+
-------------------
|
105 |
+
flow_resolution_shifting (`bool`, defaults to `False`):
|
106 |
+
Resolution-dependent shifting of timestep schedules.
|
107 |
+
[Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206).
|
108 |
+
TODO(aryan): We don't support this yet.
|
109 |
+
flow_base_seq_len (`int`, defaults to `256`):
|
110 |
+
Base number of tokens for images/video when applying resolution-dependent shifting.
|
111 |
+
flow_max_seq_len (`int`, defaults to `4096`):
|
112 |
+
Maximum number of tokens for images/video when applying resolution-dependent shifting.
|
113 |
+
flow_base_shift (`float`, defaults to `0.5`):
|
114 |
+
Base shift for timestep schedules when applying resolution-dependent shifting.
|
115 |
+
flow_max_shift (`float`, defaults to `1.15`):
|
116 |
+
Maximum shift for timestep schedules when applying resolution-dependent shifting.
|
117 |
+
flow_shift (`float`, defaults to `1.0`):
|
118 |
+
Instead of training with uniform/logit-normal sigmas, shift them as (shift * sigma) / (1 + (shift - 1) * sigma).
|
119 |
+
Setting it higher is helpful when trying to train models for high-resolution generation or to produce better
|
120 |
+
samples in lower number of inference steps.
|
121 |
+
flow_weighting_scheme (`str`, defaults to `none`):
|
122 |
+
We default to the "none" weighting scheme for uniform sampling and uniform loss.
|
123 |
+
Choose between ['sigma_sqrt', 'logit_normal', 'mode', 'cosmap', 'none'].
|
124 |
+
flow_logit_mean (`float`, defaults to `0.0`):
|
125 |
+
Mean to use when using the `'logit_normal'` weighting scheme.
|
126 |
+
flow_logit_std (`float`, defaults to `1.0`):
|
127 |
+
Standard deviation to use when using the `'logit_normal'` weighting scheme.
|
128 |
+
flow_mode_scale (`float`, defaults to `1.29`):
|
129 |
+
Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.
|
130 |
+
|
131 |
+
TRAINING ARGUMENTS
|
132 |
+
------------------
|
133 |
+
training_type (`str`, defaults to `None`):
|
134 |
+
Type of training to perform. Choose between ['lora'].
|
135 |
+
seed (`int`, defaults to `42`):
|
136 |
+
A seed for reproducible training.
|
137 |
+
batch_size (`int`, defaults to `1`):
|
138 |
+
Per-device batch size.
|
139 |
+
train_epochs (`int`, defaults to `1`):
|
140 |
+
Number of training epochs.
|
141 |
+
train_steps (`int`, defaults to `None`):
|
142 |
+
Total number of training steps to perform. If provided, overrides `train_epochs`.
|
143 |
+
rank (`int`, defaults to `128`):
|
144 |
+
The rank for LoRA matrices.
|
145 |
+
lora_alpha (`float`, defaults to `64`):
|
146 |
+
The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.
|
147 |
+
target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`):
|
148 |
+
The target modules for LoRA. Make sure to modify this based on the model.
|
149 |
+
gradient_accumulation_steps (`int`, defaults to `1`):
|
150 |
+
Number of gradients steps to accumulate before performing an optimizer step.
|
151 |
+
gradient_checkpointing (`bool`, defaults to `False`):
|
152 |
+
Whether or not to use gradient/activation checkpointing to save memory at the expense of slower
|
153 |
+
backward pass.
|
154 |
+
checkpointing_steps (`int`, defaults to `500`):
|
155 |
+
Save a checkpoint of the training state every X training steps. These checkpoints can be used both
|
156 |
+
as final checkpoints in case they are better than the last checkpoint, and are also suitable for
|
157 |
+
resuming training using `resume_from_checkpoint`.
|
158 |
+
checkpointing_limit (`int`, defaults to `None`):
|
159 |
+
Max number of checkpoints to store.
|
160 |
+
resume_from_checkpoint (`str`, defaults to `None`):
|
161 |
+
Whether training should be resumed from a previous checkpoint. Use a path saved by `checkpointing_steps`,
|
162 |
+
or `"latest"` to automatically select the last available checkpoint.
|
163 |
+
|
164 |
+
OPTIMIZER ARGUMENTS
|
165 |
+
-------------------
|
166 |
+
optimizer (`str`, defaults to `adamw`):
|
167 |
+
The optimizer type to use. Choose between ['adam', 'adamw'].
|
168 |
+
use_8bit_bnb (`bool`, defaults to `False`):
|
169 |
+
Whether to use 8bit variant of the `optimizer` using `bitsandbytes`.
|
170 |
+
lr (`float`, defaults to `1e-4`):
|
171 |
+
Initial learning rate (after the potential warmup period) to use.
|
172 |
+
scale_lr (`bool`, defaults to `False`):
|
173 |
+
Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.
|
174 |
+
lr_scheduler (`str`, defaults to `cosine_with_restarts`):
|
175 |
+
The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
|
176 |
+
'constant', 'constant_with_warmup'].
|
177 |
+
lr_warmup_steps (`int`, defaults to `500`):
|
178 |
+
Number of steps for the warmup in the lr scheduler.
|
179 |
+
lr_num_cycles (`int`, defaults to `1`):
|
180 |
+
Number of hard resets of the lr in cosine_with_restarts scheduler.
|
181 |
+
lr_power (`float`, defaults to `1.0`):
|
182 |
+
Power factor of the polynomial scheduler.
|
183 |
+
beta1 (`float`, defaults to `0.9`):
|
184 |
+
beta2 (`float`, defaults to `0.95`):
|
185 |
+
beta3 (`float`, defaults to `0.999`):
|
186 |
+
weight_decay (`float`, defaults to `0.0001`):
|
187 |
+
Penalty for large weights in the model.
|
188 |
+
epsilon (`float`, defaults to `1e-8`):
|
189 |
+
Small value to avoid division by zero in the optimizer.
|
190 |
+
max_grad_norm (`float`, defaults to `1.0`):
|
191 |
+
Maximum gradient norm to clip the gradients.
|
192 |
+
|
193 |
+
VALIDATION ARGUMENTS
|
194 |
+
--------------------
|
195 |
+
validation_prompts (`List[str]`, defaults to `None`):
|
196 |
+
List of prompts to use for validation. If not provided, a random prompt will be selected from the training
|
197 |
+
dataset.
|
198 |
+
validation_images (`List[str]`, defaults to `None`):
|
199 |
+
List of image paths to use for validation.
|
200 |
+
validation_videos (`List[str]`, defaults to `None`):
|
201 |
+
List of video paths to use for validation.
|
202 |
+
validation_heights (`List[int]`, defaults to `None`):
|
203 |
+
List of heights for the validation videos.
|
204 |
+
validation_widths (`List[int]`, defaults to `None`):
|
205 |
+
List of widths for the validation videos.
|
206 |
+
validation_num_frames (`List[int]`, defaults to `None`):
|
207 |
+
List of number of frames for the validation videos.
|
208 |
+
num_validation_videos_per_prompt (`int`, defaults to `1`):
|
209 |
+
Number of videos to use for validation per prompt.
|
210 |
+
validation_every_n_epochs (`int`, defaults to `None`):
|
211 |
+
Perform validation every `n` training epochs.
|
212 |
+
validation_every_n_steps (`int`, defaults to `None`):
|
213 |
+
Perform validation every `n` training steps.
|
214 |
+
enable_model_cpu_offload (`bool`, defaults to `False`):
|
215 |
+
Whether or not to offload different modeling components to CPU during validation.
|
216 |
+
validation_frame_rate (`int`, defaults to `25`):
|
217 |
+
Frame rate to use for the validation videos. This value is defaulted to 25, as used in LTX Video pipeline.
|
218 |
+
|
219 |
+
MISCELLANEOUS ARGUMENTS
|
220 |
+
-----------------------
|
221 |
+
tracker_name (`str`, defaults to `finetrainers`):
|
222 |
+
Name of the tracker/project to use for logging training metrics.
|
223 |
+
push_to_hub (`bool`, defaults to `False`):
|
224 |
+
Whether or not to push the model to the Hugging Face Hub.
|
225 |
+
hub_token (`str`, defaults to `None`):
|
226 |
+
The API token to use for pushing the model to the Hugging Face Hub.
|
227 |
+
hub_model_id (`str`, defaults to `None`):
|
228 |
+
The model identifier to use for pushing the model to the Hugging Face Hub.
|
229 |
+
output_dir (`str`, defaults to `None`):
|
230 |
+
The directory where the model checkpoints and logs will be stored.
|
231 |
+
logging_dir (`str`, defaults to `logs`):
|
232 |
+
The directory where the logs will be stored.
|
233 |
+
allow_tf32 (`bool`, defaults to `False`):
|
234 |
+
Whether or not to allow the use of TF32 matmul on compatible hardware.
|
235 |
+
nccl_timeout (`int`, defaults to `1800`):
|
236 |
+
Timeout for the NCCL communication.
|
237 |
+
report_to (`str`, defaults to `wandb`):
|
238 |
+
The name of the logger to use for logging training metrics. Choose between ['wandb'].
|
239 |
+
"""
|
240 |
+
|
241 |
+
# Model arguments
|
242 |
+
model_name: str = None
|
243 |
+
pretrained_model_name_or_path: str = None
|
244 |
+
revision: Optional[str] = None
|
245 |
+
variant: Optional[str] = None
|
246 |
+
cache_dir: Optional[str] = None
|
247 |
+
text_encoder_dtype: torch.dtype = torch.bfloat16
|
248 |
+
text_encoder_2_dtype: torch.dtype = torch.bfloat16
|
249 |
+
text_encoder_3_dtype: torch.dtype = torch.bfloat16
|
250 |
+
transformer_dtype: torch.dtype = torch.bfloat16
|
251 |
+
vae_dtype: torch.dtype = torch.bfloat16
|
252 |
+
layerwise_upcasting_modules: List[str] = []
|
253 |
+
layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
|
254 |
+
layerwise_upcasting_skip_modules_pattern: List[str] = [
|
255 |
+
"patch_embed",
|
256 |
+
"pos_embed",
|
257 |
+
"x_embedder",
|
258 |
+
"context_embedder",
|
259 |
+
"time_embed",
|
260 |
+
"^proj_in$",
|
261 |
+
"^proj_out$",
|
262 |
+
"norm",
|
263 |
+
]
|
264 |
+
|
265 |
+
# Dataset arguments
|
266 |
+
data_root: str = None
|
267 |
+
dataset_file: Optional[str] = None
|
268 |
+
video_column: str = None
|
269 |
+
caption_column: str = None
|
270 |
+
id_token: Optional[str] = None
|
271 |
+
image_resolution_buckets: List[Tuple[int, int]] = None
|
272 |
+
video_resolution_buckets: List[Tuple[int, int, int]] = None
|
273 |
+
video_reshape_mode: Optional[str] = None
|
274 |
+
caption_dropout_p: float = 0.00
|
275 |
+
caption_dropout_technique: str = "empty"
|
276 |
+
precompute_conditions: bool = False
|
277 |
+
remove_common_llm_caption_prefixes: bool = False
|
278 |
+
|
279 |
+
# Dataloader arguments
|
280 |
+
dataloader_num_workers: int = 0
|
281 |
+
pin_memory: bool = False
|
282 |
+
|
283 |
+
# Diffusion arguments
|
284 |
+
flow_resolution_shifting: bool = False
|
285 |
+
flow_base_seq_len: int = 256
|
286 |
+
flow_max_seq_len: int = 4096
|
287 |
+
flow_base_shift: float = 0.5
|
288 |
+
flow_max_shift: float = 1.15
|
289 |
+
flow_shift: float = 1.0
|
290 |
+
flow_weighting_scheme: str = "none"
|
291 |
+
flow_logit_mean: float = 0.0
|
292 |
+
flow_logit_std: float = 1.0
|
293 |
+
flow_mode_scale: float = 1.29
|
294 |
+
|
295 |
+
# Training arguments
|
296 |
+
training_type: str = None
|
297 |
+
seed: int = 42
|
298 |
+
batch_size: int = 1
|
299 |
+
train_epochs: int = 1
|
300 |
+
train_steps: int = None
|
301 |
+
rank: int = 128
|
302 |
+
lora_alpha: float = 64
|
303 |
+
target_modules: List[str] = ["to_k", "to_q", "to_v", "to_out.0"]
|
304 |
+
gradient_accumulation_steps: int = 1
|
305 |
+
gradient_checkpointing: bool = False
|
306 |
+
checkpointing_steps: int = 500
|
307 |
+
checkpointing_limit: Optional[int] = None
|
308 |
+
resume_from_checkpoint: Optional[str] = None
|
309 |
+
enable_slicing: bool = False
|
310 |
+
enable_tiling: bool = False
|
311 |
+
|
312 |
+
# Optimizer arguments
|
313 |
+
optimizer: str = "adamw"
|
314 |
+
use_8bit_bnb: bool = False
|
315 |
+
lr: float = 1e-4
|
316 |
+
scale_lr: bool = False
|
317 |
+
lr_scheduler: str = "cosine_with_restarts"
|
318 |
+
lr_warmup_steps: int = 0
|
319 |
+
lr_num_cycles: int = 1
|
320 |
+
lr_power: float = 1.0
|
321 |
+
beta1: float = 0.9
|
322 |
+
beta2: float = 0.95
|
323 |
+
beta3: float = 0.999
|
324 |
+
weight_decay: float = 0.0001
|
325 |
+
epsilon: float = 1e-8
|
326 |
+
max_grad_norm: float = 1.0
|
327 |
+
|
328 |
+
# Validation arguments
|
329 |
+
validation_prompts: List[str] = None
|
330 |
+
validation_images: List[str] = None
|
331 |
+
validation_videos: List[str] = None
|
332 |
+
validation_heights: List[int] = None
|
333 |
+
validation_widths: List[int] = None
|
334 |
+
validation_num_frames: List[int] = None
|
335 |
+
num_validation_videos_per_prompt: int = 1
|
336 |
+
validation_every_n_epochs: Optional[int] = None
|
337 |
+
validation_every_n_steps: Optional[int] = None
|
338 |
+
enable_model_cpu_offload: bool = False
|
339 |
+
validation_frame_rate: int = 25
|
340 |
+
|
341 |
+
# Miscellaneous arguments
|
342 |
+
tracker_name: str = "finetrainers"
|
343 |
+
push_to_hub: bool = False
|
344 |
+
hub_token: Optional[str] = None
|
345 |
+
hub_model_id: Optional[str] = None
|
346 |
+
output_dir: str = None
|
347 |
+
logging_dir: Optional[str] = "logs"
|
348 |
+
allow_tf32: bool = False
|
349 |
+
nccl_timeout: int = 1800 # 30 minutes
|
350 |
+
report_to: str = "wandb"
|
351 |
+
|
352 |
+
def to_dict(self) -> Dict[str, Any]:
|
353 |
+
return {
|
354 |
+
"model_arguments": {
|
355 |
+
"model_name": self.model_name,
|
356 |
+
"pretrained_model_name_or_path": self.pretrained_model_name_or_path,
|
357 |
+
"revision": self.revision,
|
358 |
+
"variant": self.variant,
|
359 |
+
"cache_dir": self.cache_dir,
|
360 |
+
"text_encoder_dtype": self.text_encoder_dtype,
|
361 |
+
"text_encoder_2_dtype": self.text_encoder_2_dtype,
|
362 |
+
"text_encoder_3_dtype": self.text_encoder_3_dtype,
|
363 |
+
"transformer_dtype": self.transformer_dtype,
|
364 |
+
"vae_dtype": self.vae_dtype,
|
365 |
+
"layerwise_upcasting_modules": self.layerwise_upcasting_modules,
|
366 |
+
"layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
|
367 |
+
"layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
|
368 |
+
},
|
369 |
+
"dataset_arguments": {
|
370 |
+
"data_root": self.data_root,
|
371 |
+
"dataset_file": self.dataset_file,
|
372 |
+
"video_column": self.video_column,
|
373 |
+
"caption_column": self.caption_column,
|
374 |
+
"id_token": self.id_token,
|
375 |
+
"image_resolution_buckets": self.image_resolution_buckets,
|
376 |
+
"video_resolution_buckets": self.video_resolution_buckets,
|
377 |
+
"video_reshape_mode": self.video_reshape_mode,
|
378 |
+
"caption_dropout_p": self.caption_dropout_p,
|
379 |
+
"caption_dropout_technique": self.caption_dropout_technique,
|
380 |
+
"precompute_conditions": self.precompute_conditions,
|
381 |
+
"remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes,
|
382 |
+
},
|
383 |
+
"dataloader_arguments": {
|
384 |
+
"dataloader_num_workers": self.dataloader_num_workers,
|
385 |
+
"pin_memory": self.pin_memory,
|
386 |
+
},
|
387 |
+
"diffusion_arguments": {
|
388 |
+
"flow_resolution_shifting": self.flow_resolution_shifting,
|
389 |
+
"flow_base_seq_len": self.flow_base_seq_len,
|
390 |
+
"flow_max_seq_len": self.flow_max_seq_len,
|
391 |
+
"flow_base_shift": self.flow_base_shift,
|
392 |
+
"flow_max_shift": self.flow_max_shift,
|
393 |
+
"flow_shift": self.flow_shift,
|
394 |
+
"flow_weighting_scheme": self.flow_weighting_scheme,
|
395 |
+
"flow_logit_mean": self.flow_logit_mean,
|
396 |
+
"flow_logit_std": self.flow_logit_std,
|
397 |
+
"flow_mode_scale": self.flow_mode_scale,
|
398 |
+
},
|
399 |
+
"training_arguments": {
|
400 |
+
"training_type": self.training_type,
|
401 |
+
"seed": self.seed,
|
402 |
+
"batch_size": self.batch_size,
|
403 |
+
"train_epochs": self.train_epochs,
|
404 |
+
"train_steps": self.train_steps,
|
405 |
+
"rank": self.rank,
|
406 |
+
"lora_alpha": self.lora_alpha,
|
407 |
+
"target_modules": self.target_modules,
|
408 |
+
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
409 |
+
"gradient_checkpointing": self.gradient_checkpointing,
|
410 |
+
"checkpointing_steps": self.checkpointing_steps,
|
411 |
+
"checkpointing_limit": self.checkpointing_limit,
|
412 |
+
"resume_from_checkpoint": self.resume_from_checkpoint,
|
413 |
+
"enable_slicing": self.enable_slicing,
|
414 |
+
"enable_tiling": self.enable_tiling,
|
415 |
+
},
|
416 |
+
"optimizer_arguments": {
|
417 |
+
"optimizer": self.optimizer,
|
418 |
+
"use_8bit_bnb": self.use_8bit_bnb,
|
419 |
+
"lr": self.lr,
|
420 |
+
"scale_lr": self.scale_lr,
|
421 |
+
"lr_scheduler": self.lr_scheduler,
|
422 |
+
"lr_warmup_steps": self.lr_warmup_steps,
|
423 |
+
"lr_num_cycles": self.lr_num_cycles,
|
424 |
+
"lr_power": self.lr_power,
|
425 |
+
"beta1": self.beta1,
|
426 |
+
"beta2": self.beta2,
|
427 |
+
"beta3": self.beta3,
|
428 |
+
"weight_decay": self.weight_decay,
|
429 |
+
"epsilon": self.epsilon,
|
430 |
+
"max_grad_norm": self.max_grad_norm,
|
431 |
+
},
|
432 |
+
"validation_arguments": {
|
433 |
+
"validation_prompts": self.validation_prompts,
|
434 |
+
"validation_images": self.validation_images,
|
435 |
+
"validation_videos": self.validation_videos,
|
436 |
+
"num_validation_videos_per_prompt": self.num_validation_videos_per_prompt,
|
437 |
+
"validation_every_n_epochs": self.validation_every_n_epochs,
|
438 |
+
"validation_every_n_steps": self.validation_every_n_steps,
|
439 |
+
"enable_model_cpu_offload": self.enable_model_cpu_offload,
|
440 |
+
"validation_frame_rate": self.validation_frame_rate,
|
441 |
+
},
|
442 |
+
"miscellaneous_arguments": {
|
443 |
+
"tracker_name": self.tracker_name,
|
444 |
+
"push_to_hub": self.push_to_hub,
|
445 |
+
"hub_token": self.hub_token,
|
446 |
+
"hub_model_id": self.hub_model_id,
|
447 |
+
"output_dir": self.output_dir,
|
448 |
+
"logging_dir": self.logging_dir,
|
449 |
+
"allow_tf32": self.allow_tf32,
|
450 |
+
"nccl_timeout": self.nccl_timeout,
|
451 |
+
"report_to": self.report_to,
|
452 |
+
},
|
453 |
+
}
|
454 |
+
|
455 |
+
|
456 |
+
# TODO(aryan): handle more informative messages
|
457 |
+
_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv
|
458 |
+
|
459 |
+
|
460 |
+
def parse_arguments() -> Args:
|
461 |
+
parser = argparse.ArgumentParser()
|
462 |
+
|
463 |
+
if _IS_ARGUMENTS_REQUIRED:
|
464 |
+
_add_model_arguments(parser)
|
465 |
+
_add_dataset_arguments(parser)
|
466 |
+
_add_dataloader_arguments(parser)
|
467 |
+
_add_diffusion_arguments(parser)
|
468 |
+
_add_training_arguments(parser)
|
469 |
+
_add_optimizer_arguments(parser)
|
470 |
+
_add_validation_arguments(parser)
|
471 |
+
_add_miscellaneous_arguments(parser)
|
472 |
+
|
473 |
+
args = parser.parse_args()
|
474 |
+
return _map_to_args_type(args)
|
475 |
+
else:
|
476 |
+
_add_helper_arguments(parser)
|
477 |
+
|
478 |
+
args = parser.parse_args()
|
479 |
+
_display_helper_messages(args)
|
480 |
+
sys.exit(0)
|
481 |
+
|
482 |
+
|
483 |
+
def validate_args(args: Args):
|
484 |
+
_validated_model_args(args)
|
485 |
+
_validate_training_args(args)
|
486 |
+
_validate_validation_args(args)
|
487 |
+
|
488 |
+
|
489 |
+
def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
|
490 |
+
parser.add_argument(
|
491 |
+
"--model_name",
|
492 |
+
type=str,
|
493 |
+
required=True,
|
494 |
+
choices=list(SUPPORTED_MODEL_CONFIGS.keys()),
|
495 |
+
help="Name of model to train.",
|
496 |
+
)
|
497 |
+
parser.add_argument(
|
498 |
+
"--pretrained_model_name_or_path",
|
499 |
+
type=str,
|
500 |
+
required=True,
|
501 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
502 |
+
)
|
503 |
+
parser.add_argument(
|
504 |
+
"--revision",
|
505 |
+
type=str,
|
506 |
+
default=None,
|
507 |
+
required=False,
|
508 |
+
help="Revision of pretrained model identifier from huggingface.co/models.",
|
509 |
+
)
|
510 |
+
parser.add_argument(
|
511 |
+
"--variant",
|
512 |
+
type=str,
|
513 |
+
default=None,
|
514 |
+
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
515 |
+
)
|
516 |
+
parser.add_argument(
|
517 |
+
"--cache_dir",
|
518 |
+
type=str,
|
519 |
+
default=None,
|
520 |
+
help="The directory where the downloaded models and datasets will be stored.",
|
521 |
+
)
|
522 |
+
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
|
523 |
+
parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
|
524 |
+
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
|
525 |
+
parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
|
526 |
+
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
|
527 |
+
parser.add_argument(
|
528 |
+
"--layerwise_upcasting_modules",
|
529 |
+
type=str,
|
530 |
+
default=[],
|
531 |
+
nargs="+",
|
532 |
+
choices=["transformer"],
|
533 |
+
help="Modules that should have fp8 storage weights but higher precision computation.",
|
534 |
+
)
|
535 |
+
parser.add_argument(
|
536 |
+
"--layerwise_upcasting_storage_dtype",
|
537 |
+
type=str,
|
538 |
+
default="float8_e4m3fn",
|
539 |
+
choices=["float8_e4m3fn", "float8_e5m2"],
|
540 |
+
help="Data type for the layerwise upcasting storage.",
|
541 |
+
)
|
542 |
+
parser.add_argument(
|
543 |
+
"--layerwise_upcasting_skip_modules_pattern",
|
544 |
+
type=str,
|
545 |
+
default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
|
546 |
+
nargs="+",
|
547 |
+
help="Modules to skip for layerwise upcasting.",
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
|
552 |
+
def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]:
|
553 |
+
return tuple(map(int, resolution_bucket.split("x")))
|
554 |
+
|
555 |
+
def parse_image_resolution_bucket(resolution_bucket: str) -> Tuple[int, int]:
|
556 |
+
resolution_bucket = parse_resolution_bucket(resolution_bucket)
|
557 |
+
assert (
|
558 |
+
len(resolution_bucket) == 2
|
559 |
+
), f"Expected 2D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
|
560 |
+
return resolution_bucket
|
561 |
+
|
562 |
+
def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int]:
|
563 |
+
resolution_bucket = parse_resolution_bucket(resolution_bucket)
|
564 |
+
assert (
|
565 |
+
len(resolution_bucket) == 3
|
566 |
+
), f"Expected 3D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
|
567 |
+
return resolution_bucket
|
568 |
+
|
569 |
+
parser.add_argument(
|
570 |
+
"--data_root",
|
571 |
+
type=str,
|
572 |
+
required=True,
|
573 |
+
help=("A folder containing the training data."),
|
574 |
+
)
|
575 |
+
parser.add_argument(
|
576 |
+
"--dataset_file",
|
577 |
+
type=str,
|
578 |
+
default=None,
|
579 |
+
help=("Path to a CSV file if loading prompts/video paths using this format."),
|
580 |
+
)
|
581 |
+
parser.add_argument(
|
582 |
+
"--video_column",
|
583 |
+
type=str,
|
584 |
+
default="video",
|
585 |
+
help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
|
586 |
+
)
|
587 |
+
parser.add_argument(
|
588 |
+
"--caption_column",
|
589 |
+
type=str,
|
590 |
+
default="text",
|
591 |
+
help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
|
592 |
+
)
|
593 |
+
parser.add_argument(
|
594 |
+
"--id_token",
|
595 |
+
type=str,
|
596 |
+
default=None,
|
597 |
+
help="Identifier token appended to the start of each prompt if provided.",
|
598 |
+
)
|
599 |
+
parser.add_argument(
|
600 |
+
"--image_resolution_buckets",
|
601 |
+
type=parse_image_resolution_bucket,
|
602 |
+
default=None,
|
603 |
+
nargs="+",
|
604 |
+
help="Resolution buckets for images.",
|
605 |
+
)
|
606 |
+
parser.add_argument(
|
607 |
+
"--video_resolution_buckets",
|
608 |
+
type=parse_video_resolution_bucket,
|
609 |
+
default=None,
|
610 |
+
nargs="+",
|
611 |
+
help="Resolution buckets for videos.",
|
612 |
+
)
|
613 |
+
parser.add_argument(
|
614 |
+
"--video_reshape_mode",
|
615 |
+
type=str,
|
616 |
+
default=None,
|
617 |
+
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
|
618 |
+
)
|
619 |
+
parser.add_argument(
|
620 |
+
"--caption_dropout_p",
|
621 |
+
type=float,
|
622 |
+
default=0.00,
|
623 |
+
help="Probability of dropout for the caption tokens.",
|
624 |
+
)
|
625 |
+
parser.add_argument(
|
626 |
+
"--caption_dropout_technique",
|
627 |
+
type=str,
|
628 |
+
default="empty",
|
629 |
+
choices=["empty", "zero"],
|
630 |
+
help="Technique to use for caption dropout.",
|
631 |
+
)
|
632 |
+
parser.add_argument(
|
633 |
+
"--precompute_conditions",
|
634 |
+
action="store_true",
|
635 |
+
help="Whether or not to precompute the conditionings for the model.",
|
636 |
+
)
|
637 |
+
parser.add_argument(
|
638 |
+
"--remove_common_llm_caption_prefixes",
|
639 |
+
action="store_true",
|
640 |
+
help="Whether or not to remove common LLM caption prefixes.",
|
641 |
+
)
|
642 |
+
|
643 |
+
|
644 |
+
def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
|
645 |
+
parser.add_argument(
|
646 |
+
"--dataloader_num_workers",
|
647 |
+
type=int,
|
648 |
+
default=0,
|
649 |
+
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
650 |
+
)
|
651 |
+
parser.add_argument(
|
652 |
+
"--pin_memory",
|
653 |
+
action="store_true",
|
654 |
+
help="Whether or not to use the pinned memory setting in pytorch dataloader.",
|
655 |
+
)
|
656 |
+
|
657 |
+
|
658 |
+
def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
|
659 |
+
parser.add_argument(
|
660 |
+
"--flow_resolution_shifting",
|
661 |
+
action="store_true",
|
662 |
+
help="Resolution-dependent shifting of timestep schedules.",
|
663 |
+
)
|
664 |
+
parser.add_argument(
|
665 |
+
"--flow_base_seq_len",
|
666 |
+
type=int,
|
667 |
+
default=256,
|
668 |
+
help="Base image/video sequence length for the diffusion model.",
|
669 |
+
)
|
670 |
+
parser.add_argument(
|
671 |
+
"--flow_max_seq_len",
|
672 |
+
type=int,
|
673 |
+
default=4096,
|
674 |
+
help="Maximum image/video sequence length for the diffusion model.",
|
675 |
+
)
|
676 |
+
parser.add_argument(
|
677 |
+
"--flow_base_shift",
|
678 |
+
type=float,
|
679 |
+
default=0.5,
|
680 |
+
help="Base shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
|
681 |
+
)
|
682 |
+
parser.add_argument(
|
683 |
+
"--flow_max_shift",
|
684 |
+
type=float,
|
685 |
+
default=1.15,
|
686 |
+
help="Maximum shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
|
687 |
+
)
|
688 |
+
parser.add_argument(
|
689 |
+
"--flow_shift",
|
690 |
+
type=float,
|
691 |
+
default=1.0,
|
692 |
+
help="Shift value to use for the flow matching timestep schedule.",
|
693 |
+
)
|
694 |
+
parser.add_argument(
|
695 |
+
"--flow_weighting_scheme",
|
696 |
+
type=str,
|
697 |
+
default="none",
|
698 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
|
699 |
+
help='We default to the "none" weighting scheme for uniform sampling and uniform loss',
|
700 |
+
)
|
701 |
+
parser.add_argument(
|
702 |
+
"--flow_logit_mean",
|
703 |
+
type=float,
|
704 |
+
default=0.0,
|
705 |
+
help="Mean to use when using the `'logit_normal'` weighting scheme.",
|
706 |
+
)
|
707 |
+
parser.add_argument(
|
708 |
+
"--flow_logit_std",
|
709 |
+
type=float,
|
710 |
+
default=1.0,
|
711 |
+
help="Standard deviation to use when using the `'logit_normal'` weighting scheme.",
|
712 |
+
)
|
713 |
+
parser.add_argument(
|
714 |
+
"--flow_mode_scale",
|
715 |
+
type=float,
|
716 |
+
default=1.29,
|
717 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
718 |
+
)
|
719 |
+
|
720 |
+
|
721 |
+
def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
|
722 |
+
# TODO: support full finetuning and other kinds
|
723 |
+
parser.add_argument(
|
724 |
+
"--training_type",
|
725 |
+
type=str,
|
726 |
+
choices=["lora", "full-finetune"],
|
727 |
+
required=True,
|
728 |
+
help="Type of training to perform. Choose between ['lora', 'full-finetune']",
|
729 |
+
)
|
730 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
731 |
+
parser.add_argument(
|
732 |
+
"--batch_size",
|
733 |
+
type=int,
|
734 |
+
default=1,
|
735 |
+
help="Batch size (per device) for the training dataloader.",
|
736 |
+
)
|
737 |
+
parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.")
|
738 |
+
parser.add_argument(
|
739 |
+
"--train_steps",
|
740 |
+
type=int,
|
741 |
+
default=None,
|
742 |
+
help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
|
743 |
+
)
|
744 |
+
parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
|
745 |
+
parser.add_argument(
|
746 |
+
"--lora_alpha",
|
747 |
+
type=int,
|
748 |
+
default=64,
|
749 |
+
help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
|
750 |
+
)
|
751 |
+
parser.add_argument(
|
752 |
+
"--target_modules",
|
753 |
+
type=str,
|
754 |
+
default=["to_k", "to_q", "to_v", "to_out.0"],
|
755 |
+
nargs="+",
|
756 |
+
help="The target modules for LoRA.",
|
757 |
+
)
|
758 |
+
parser.add_argument(
|
759 |
+
"--gradient_accumulation_steps",
|
760 |
+
type=int,
|
761 |
+
default=1,
|
762 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
763 |
+
)
|
764 |
+
parser.add_argument(
|
765 |
+
"--gradient_checkpointing",
|
766 |
+
action="store_true",
|
767 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
768 |
+
)
|
769 |
+
parser.add_argument(
|
770 |
+
"--checkpointing_steps",
|
771 |
+
type=int,
|
772 |
+
default=500,
|
773 |
+
help=(
|
774 |
+
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
|
775 |
+
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
|
776 |
+
" training using `--resume_from_checkpoint`."
|
777 |
+
),
|
778 |
+
)
|
779 |
+
parser.add_argument(
|
780 |
+
"--checkpointing_limit",
|
781 |
+
type=int,
|
782 |
+
default=None,
|
783 |
+
help=("Max number of checkpoints to store."),
|
784 |
+
)
|
785 |
+
parser.add_argument(
|
786 |
+
"--resume_from_checkpoint",
|
787 |
+
type=str,
|
788 |
+
default=None,
|
789 |
+
help=(
|
790 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
791 |
+
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
792 |
+
),
|
793 |
+
)
|
794 |
+
parser.add_argument(
|
795 |
+
"--enable_slicing",
|
796 |
+
action="store_true",
|
797 |
+
help="Whether or not to use VAE slicing for saving memory.",
|
798 |
+
)
|
799 |
+
parser.add_argument(
|
800 |
+
"--enable_tiling",
|
801 |
+
action="store_true",
|
802 |
+
help="Whether or not to use VAE tiling for saving memory.",
|
803 |
+
)
|
804 |
+
|
805 |
+
|
806 |
+
def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
|
807 |
+
parser.add_argument(
|
808 |
+
"--lr",
|
809 |
+
type=float,
|
810 |
+
default=1e-4,
|
811 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
812 |
+
)
|
813 |
+
parser.add_argument(
|
814 |
+
"--scale_lr",
|
815 |
+
action="store_true",
|
816 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
817 |
+
)
|
818 |
+
parser.add_argument(
|
819 |
+
"--lr_scheduler",
|
820 |
+
type=str,
|
821 |
+
default="constant",
|
822 |
+
help=(
|
823 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
824 |
+
' "constant", "constant_with_warmup"]'
|
825 |
+
),
|
826 |
+
)
|
827 |
+
parser.add_argument(
|
828 |
+
"--lr_warmup_steps",
|
829 |
+
type=int,
|
830 |
+
default=500,
|
831 |
+
help="Number of steps for the warmup in the lr scheduler.",
|
832 |
+
)
|
833 |
+
parser.add_argument(
|
834 |
+
"--lr_num_cycles",
|
835 |
+
type=int,
|
836 |
+
default=1,
|
837 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
838 |
+
)
|
839 |
+
parser.add_argument(
|
840 |
+
"--lr_power",
|
841 |
+
type=float,
|
842 |
+
default=1.0,
|
843 |
+
help="Power factor of the polynomial scheduler.",
|
844 |
+
)
|
845 |
+
parser.add_argument(
|
846 |
+
"--optimizer",
|
847 |
+
type=lambda s: s.lower(),
|
848 |
+
default="adam",
|
849 |
+
choices=["adam", "adamw"],
|
850 |
+
help=("The optimizer type to use."),
|
851 |
+
)
|
852 |
+
parser.add_argument(
|
853 |
+
"--use_8bit_bnb",
|
854 |
+
action="store_true",
|
855 |
+
help=("Whether to use 8bit variant of the `--optimizer` using `bitsandbytes`."),
|
856 |
+
)
|
857 |
+
parser.add_argument(
|
858 |
+
"--beta1",
|
859 |
+
type=float,
|
860 |
+
default=0.9,
|
861 |
+
help="The beta1 parameter for the Adam and Prodigy optimizers.",
|
862 |
+
)
|
863 |
+
parser.add_argument(
|
864 |
+
"--beta2",
|
865 |
+
type=float,
|
866 |
+
default=0.95,
|
867 |
+
help="The beta2 parameter for the Adam and Prodigy optimizers.",
|
868 |
+
)
|
869 |
+
parser.add_argument(
|
870 |
+
"--beta3",
|
871 |
+
type=float,
|
872 |
+
default=None,
|
873 |
+
help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
|
874 |
+
)
|
875 |
+
parser.add_argument(
|
876 |
+
"--weight_decay",
|
877 |
+
type=float,
|
878 |
+
default=1e-04,
|
879 |
+
help="Weight decay to use for optimizer.",
|
880 |
+
)
|
881 |
+
parser.add_argument(
|
882 |
+
"--epsilon",
|
883 |
+
type=float,
|
884 |
+
default=1e-8,
|
885 |
+
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
|
886 |
+
)
|
887 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
888 |
+
|
889 |
+
|
890 |
+
def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
|
891 |
+
parser.add_argument(
|
892 |
+
"--validation_prompts",
|
893 |
+
type=str,
|
894 |
+
default=None,
|
895 |
+
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
|
896 |
+
)
|
897 |
+
parser.add_argument(
|
898 |
+
"--validation_images",
|
899 |
+
type=str,
|
900 |
+
default=None,
|
901 |
+
help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
|
902 |
+
)
|
903 |
+
parser.add_argument(
|
904 |
+
"--validation_videos",
|
905 |
+
type=str,
|
906 |
+
default=None,
|
907 |
+
help="One or more video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
|
908 |
+
)
|
909 |
+
parser.add_argument(
|
910 |
+
"--validation_separator",
|
911 |
+
type=str,
|
912 |
+
default=":::",
|
913 |
+
help="String that separates multiple validation prompts",
|
914 |
+
)
|
915 |
+
parser.add_argument(
|
916 |
+
"--num_validation_videos",
|
917 |
+
type=int,
|
918 |
+
default=1,
|
919 |
+
help="Number of videos that should be generated during validation per `validation_prompt`.",
|
920 |
+
)
|
921 |
+
parser.add_argument(
|
922 |
+
"--validation_epochs",
|
923 |
+
type=int,
|
924 |
+
default=None,
|
925 |
+
help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
|
926 |
+
)
|
927 |
+
parser.add_argument(
|
928 |
+
"--validation_steps",
|
929 |
+
type=int,
|
930 |
+
default=None,
|
931 |
+
help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
|
932 |
+
)
|
933 |
+
parser.add_argument(
|
934 |
+
"--validation_frame_rate",
|
935 |
+
type=int,
|
936 |
+
default=25,
|
937 |
+
help="Frame rate to use for the validation videos.",
|
938 |
+
)
|
939 |
+
parser.add_argument(
|
940 |
+
"--enable_model_cpu_offload",
|
941 |
+
action="store_true",
|
942 |
+
help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
|
943 |
+
)
|
944 |
+
|
945 |
+
|
946 |
+
def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
|
947 |
+
parser.add_argument("--tracker_name", type=str, default="finetrainers", help="Project tracker name")
|
948 |
+
parser.add_argument(
|
949 |
+
"--push_to_hub",
|
950 |
+
action="store_true",
|
951 |
+
help="Whether or not to push the model to the Hub.",
|
952 |
+
)
|
953 |
+
parser.add_argument(
|
954 |
+
"--hub_token",
|
955 |
+
type=str,
|
956 |
+
default=None,
|
957 |
+
help="The token to use to push to the Model Hub.",
|
958 |
+
)
|
959 |
+
parser.add_argument(
|
960 |
+
"--hub_model_id",
|
961 |
+
type=str,
|
962 |
+
default=None,
|
963 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
964 |
+
)
|
965 |
+
parser.add_argument(
|
966 |
+
"--output_dir",
|
967 |
+
type=str,
|
968 |
+
default="finetrainers-training",
|
969 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
970 |
+
)
|
971 |
+
parser.add_argument(
|
972 |
+
"--logging_dir",
|
973 |
+
type=str,
|
974 |
+
default="logs",
|
975 |
+
help="Directory where logs are stored.",
|
976 |
+
)
|
977 |
+
parser.add_argument(
|
978 |
+
"--allow_tf32",
|
979 |
+
action="store_true",
|
980 |
+
help=(
|
981 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
982 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
983 |
+
),
|
984 |
+
)
|
985 |
+
parser.add_argument(
|
986 |
+
"--nccl_timeout",
|
987 |
+
type=int,
|
988 |
+
default=600,
|
989 |
+
help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
|
990 |
+
)
|
991 |
+
parser.add_argument(
|
992 |
+
"--report_to",
|
993 |
+
type=str,
|
994 |
+
default="none",
|
995 |
+
choices=["none", "wandb"],
|
996 |
+
help="The integration to report the results and logs to.",
|
997 |
+
)
|
998 |
+
|
999 |
+
|
1000 |
+
def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
|
1001 |
+
parser.add_argument(
|
1002 |
+
"--list_models",
|
1003 |
+
action="store_true",
|
1004 |
+
help="List all the supported models.",
|
1005 |
+
)
|
1006 |
+
|
1007 |
+
|
1008 |
+
_DTYPE_MAP = {
|
1009 |
+
"bf16": torch.bfloat16,
|
1010 |
+
"fp16": torch.float16,
|
1011 |
+
"fp32": torch.float32,
|
1012 |
+
"float8_e4m3fn": torch.float8_e4m3fn,
|
1013 |
+
"float8_e5m2": torch.float8_e5m2,
|
1014 |
+
}
|
1015 |
+
|
1016 |
+
|
1017 |
+
def _map_to_args_type(args: Dict[str, Any]) -> Args:
|
1018 |
+
result_args = Args()
|
1019 |
+
|
1020 |
+
# Model arguments
|
1021 |
+
result_args.model_name = args.model_name
|
1022 |
+
result_args.pretrained_model_name_or_path = args.pretrained_model_name_or_path
|
1023 |
+
result_args.revision = args.revision
|
1024 |
+
result_args.variant = args.variant
|
1025 |
+
result_args.cache_dir = args.cache_dir
|
1026 |
+
result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
|
1027 |
+
result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
|
1028 |
+
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
|
1029 |
+
result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
|
1030 |
+
result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
|
1031 |
+
result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
|
1032 |
+
result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
|
1033 |
+
result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
|
1034 |
+
|
1035 |
+
# Dataset arguments
|
1036 |
+
if args.data_root is None and args.dataset_file is None:
|
1037 |
+
raise ValueError("At least one of `data_root` or `dataset_file` should be provided.")
|
1038 |
+
|
1039 |
+
result_args.data_root = args.data_root
|
1040 |
+
result_args.dataset_file = args.dataset_file
|
1041 |
+
result_args.video_column = args.video_column
|
1042 |
+
result_args.caption_column = args.caption_column
|
1043 |
+
result_args.id_token = args.id_token
|
1044 |
+
result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS
|
1045 |
+
result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
|
1046 |
+
result_args.video_reshape_mode = args.video_reshape_mode
|
1047 |
+
result_args.caption_dropout_p = args.caption_dropout_p
|
1048 |
+
result_args.caption_dropout_technique = args.caption_dropout_technique
|
1049 |
+
result_args.precompute_conditions = args.precompute_conditions
|
1050 |
+
result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes
|
1051 |
+
|
1052 |
+
# Dataloader arguments
|
1053 |
+
result_args.dataloader_num_workers = args.dataloader_num_workers
|
1054 |
+
result_args.pin_memory = args.pin_memory
|
1055 |
+
|
1056 |
+
# Diffusion arguments
|
1057 |
+
result_args.flow_resolution_shifting = args.flow_resolution_shifting
|
1058 |
+
result_args.flow_base_seq_len = args.flow_base_seq_len
|
1059 |
+
result_args.flow_max_seq_len = args.flow_max_seq_len
|
1060 |
+
result_args.flow_base_shift = args.flow_base_shift
|
1061 |
+
result_args.flow_max_shift = args.flow_max_shift
|
1062 |
+
result_args.flow_shift = args.flow_shift
|
1063 |
+
result_args.flow_weighting_scheme = args.flow_weighting_scheme
|
1064 |
+
result_args.flow_logit_mean = args.flow_logit_mean
|
1065 |
+
result_args.flow_logit_std = args.flow_logit_std
|
1066 |
+
result_args.flow_mode_scale = args.flow_mode_scale
|
1067 |
+
|
1068 |
+
# Training arguments
|
1069 |
+
result_args.training_type = args.training_type
|
1070 |
+
result_args.seed = args.seed
|
1071 |
+
result_args.batch_size = args.batch_size
|
1072 |
+
result_args.train_epochs = args.train_epochs
|
1073 |
+
result_args.train_steps = args.train_steps
|
1074 |
+
result_args.rank = args.rank
|
1075 |
+
result_args.lora_alpha = args.lora_alpha
|
1076 |
+
result_args.target_modules = args.target_modules
|
1077 |
+
result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
|
1078 |
+
result_args.gradient_checkpointing = args.gradient_checkpointing
|
1079 |
+
result_args.checkpointing_steps = args.checkpointing_steps
|
1080 |
+
result_args.checkpointing_limit = args.checkpointing_limit
|
1081 |
+
result_args.resume_from_checkpoint = args.resume_from_checkpoint
|
1082 |
+
result_args.enable_slicing = args.enable_slicing
|
1083 |
+
result_args.enable_tiling = args.enable_tiling
|
1084 |
+
|
1085 |
+
# Optimizer arguments
|
1086 |
+
result_args.optimizer = args.optimizer or "adamw"
|
1087 |
+
result_args.use_8bit_bnb = args.use_8bit_bnb
|
1088 |
+
result_args.lr = args.lr or 1e-4
|
1089 |
+
result_args.scale_lr = args.scale_lr
|
1090 |
+
result_args.lr_scheduler = args.lr_scheduler
|
1091 |
+
result_args.lr_warmup_steps = args.lr_warmup_steps
|
1092 |
+
result_args.lr_num_cycles = args.lr_num_cycles
|
1093 |
+
result_args.lr_power = args.lr_power
|
1094 |
+
result_args.beta1 = args.beta1
|
1095 |
+
result_args.beta2 = args.beta2
|
1096 |
+
result_args.beta3 = args.beta3
|
1097 |
+
result_args.weight_decay = args.weight_decay
|
1098 |
+
result_args.epsilon = args.epsilon
|
1099 |
+
result_args.max_grad_norm = args.max_grad_norm
|
1100 |
+
|
1101 |
+
# Validation arguments
|
1102 |
+
validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else []
|
1103 |
+
validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None
|
1104 |
+
validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None
|
1105 |
+
stripped_validation_prompts = []
|
1106 |
+
validation_heights = []
|
1107 |
+
validation_widths = []
|
1108 |
+
validation_num_frames = []
|
1109 |
+
for prompt in validation_prompts:
|
1110 |
+
prompt: str
|
1111 |
+
prompt = prompt.strip()
|
1112 |
+
actual_prompt, separator, resolution = prompt.rpartition("@@@")
|
1113 |
+
stripped_validation_prompts.append(actual_prompt)
|
1114 |
+
num_frames, height, width = None, None, None
|
1115 |
+
if len(resolution) > 0:
|
1116 |
+
num_frames, height, width = map(int, resolution.split("x"))
|
1117 |
+
validation_num_frames.append(num_frames)
|
1118 |
+
validation_heights.append(height)
|
1119 |
+
validation_widths.append(width)
|
1120 |
+
|
1121 |
+
if validation_images is None:
|
1122 |
+
validation_images = [None] * len(validation_prompts)
|
1123 |
+
if validation_videos is None:
|
1124 |
+
validation_videos = [None] * len(validation_prompts)
|
1125 |
+
|
1126 |
+
result_args.validation_prompts = stripped_validation_prompts
|
1127 |
+
result_args.validation_heights = validation_heights
|
1128 |
+
result_args.validation_widths = validation_widths
|
1129 |
+
result_args.validation_num_frames = validation_num_frames
|
1130 |
+
result_args.validation_images = validation_images
|
1131 |
+
result_args.validation_videos = validation_videos
|
1132 |
+
|
1133 |
+
result_args.num_validation_videos_per_prompt = args.num_validation_videos
|
1134 |
+
result_args.validation_every_n_epochs = args.validation_epochs
|
1135 |
+
result_args.validation_every_n_steps = args.validation_steps
|
1136 |
+
result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
|
1137 |
+
result_args.validation_frame_rate = args.validation_frame_rate
|
1138 |
+
|
1139 |
+
# Miscellaneous arguments
|
1140 |
+
result_args.tracker_name = args.tracker_name
|
1141 |
+
result_args.push_to_hub = args.push_to_hub
|
1142 |
+
result_args.hub_token = args.hub_token
|
1143 |
+
result_args.hub_model_id = args.hub_model_id
|
1144 |
+
result_args.output_dir = args.output_dir
|
1145 |
+
result_args.logging_dir = args.logging_dir
|
1146 |
+
result_args.allow_tf32 = args.allow_tf32
|
1147 |
+
result_args.nccl_timeout = args.nccl_timeout
|
1148 |
+
result_args.report_to = args.report_to
|
1149 |
+
|
1150 |
+
return result_args
|
1151 |
+
|
1152 |
+
|
1153 |
+
def _validated_model_args(args: Args):
|
1154 |
+
if args.training_type == "full-finetune":
|
1155 |
+
assert (
|
1156 |
+
"transformer" not in args.layerwise_upcasting_modules
|
1157 |
+
), "Layerwise upcasting is not supported for full-finetune training"
|
1158 |
+
|
1159 |
+
|
1160 |
+
def _validate_training_args(args: Args):
|
1161 |
+
if args.training_type == "lora":
|
1162 |
+
assert args.rank is not None, "Rank is required for LoRA training"
|
1163 |
+
assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training"
|
1164 |
+
assert (
|
1165 |
+
args.target_modules is not None and len(args.target_modules) > 0
|
1166 |
+
), "Target modules are required for LoRA training"
|
1167 |
+
|
1168 |
+
|
1169 |
+
def _validate_validation_args(args: Args):
|
1170 |
+
assert args.validation_prompts is not None, "Validation prompts are required for validation"
|
1171 |
+
if args.validation_images is not None:
|
1172 |
+
assert len(args.validation_images) == len(
|
1173 |
+
args.validation_prompts
|
1174 |
+
), "Validation images and prompts should be of same length"
|
1175 |
+
if args.validation_videos is not None:
|
1176 |
+
assert len(args.validation_videos) == len(
|
1177 |
+
args.validation_prompts
|
1178 |
+
), "Validation videos and prompts should be of same length"
|
1179 |
+
assert len(args.validation_prompts) == len(
|
1180 |
+
args.validation_heights
|
1181 |
+
), "Validation prompts and heights should be of same length"
|
1182 |
+
assert len(args.validation_prompts) == len(
|
1183 |
+
args.validation_widths
|
1184 |
+
), "Validation prompts and widths should be of same length"
|
1185 |
+
|
1186 |
+
|
1187 |
+
def _display_helper_messages(args: argparse.Namespace):
|
1188 |
+
if args.list_models:
|
1189 |
+
print("Supported models:")
|
1190 |
+
for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
|
1191 |
+
print(f" {index + 1}. {model_name}")
|
finetrainers/constants.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
DEFAULT_HEIGHT_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
|
5 |
+
DEFAULT_WIDTH_BUCKETS = [256, 320, 384, 480, 512, 576, 720, 768, 960, 1024, 1280, 1536]
|
6 |
+
DEFAULT_FRAME_BUCKETS = [49]
|
7 |
+
|
8 |
+
DEFAULT_IMAGE_RESOLUTION_BUCKETS = []
|
9 |
+
for height in DEFAULT_HEIGHT_BUCKETS:
|
10 |
+
for width in DEFAULT_WIDTH_BUCKETS:
|
11 |
+
DEFAULT_IMAGE_RESOLUTION_BUCKETS.append((height, width))
|
12 |
+
|
13 |
+
DEFAULT_VIDEO_RESOLUTION_BUCKETS = []
|
14 |
+
for frames in DEFAULT_FRAME_BUCKETS:
|
15 |
+
for height in DEFAULT_HEIGHT_BUCKETS:
|
16 |
+
for width in DEFAULT_WIDTH_BUCKETS:
|
17 |
+
DEFAULT_VIDEO_RESOLUTION_BUCKETS.append((frames, height, width))
|
18 |
+
|
19 |
+
|
20 |
+
FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
|
21 |
+
|
22 |
+
PRECOMPUTED_DIR_NAME = "precomputed"
|
23 |
+
PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
|
24 |
+
PRECOMPUTED_LATENTS_DIR_NAME = "latents"
|
25 |
+
|
26 |
+
MODEL_DESCRIPTION = r"""
|
27 |
+
\# {model_id} {training_type} finetune
|
28 |
+
|
29 |
+
<Gallery />
|
30 |
+
|
31 |
+
\#\# Model Description
|
32 |
+
|
33 |
+
This model is a {training_type} of the `{model_id}` model.
|
34 |
+
|
35 |
+
This model was trained using the `fine-video-trainers` library - a repository containing memory-optimized scripts for training video models with [Diffusers](https://github.com/huggingface/diffusers).
|
36 |
+
|
37 |
+
\#\# Download model
|
38 |
+
|
39 |
+
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
40 |
+
|
41 |
+
\#\# Usage
|
42 |
+
|
43 |
+
Requires [🧨 Diffusers](https://github.com/huggingface/diffusers) installed.
|
44 |
+
|
45 |
+
```python
|
46 |
+
{model_example}
|
47 |
+
```
|
48 |
+
|
49 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
|
50 |
+
|
51 |
+
\#\# License
|
52 |
+
|
53 |
+
Please adhere to the license of the base model.
|
54 |
+
""".strip()
|
55 |
+
|
56 |
+
_COMMON_BEGINNING_PHRASES = (
|
57 |
+
"This video",
|
58 |
+
"The video",
|
59 |
+
"This clip",
|
60 |
+
"The clip",
|
61 |
+
"The animation",
|
62 |
+
"This image",
|
63 |
+
"The image",
|
64 |
+
"This picture",
|
65 |
+
"The picture",
|
66 |
+
)
|
67 |
+
_COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents")
|
68 |
+
|
69 |
+
COMMON_LLM_START_PHRASES = (
|
70 |
+
"In the video,",
|
71 |
+
"In this video,",
|
72 |
+
"In this video clip,",
|
73 |
+
"In the clip,",
|
74 |
+
"Caption:",
|
75 |
+
*(
|
76 |
+
f"{beginning} {continuation}"
|
77 |
+
for beginning in _COMMON_BEGINNING_PHRASES
|
78 |
+
for continuation in _COMMON_CONTINUATION_WORDS
|
79 |
+
),
|
80 |
+
)
|
finetrainers/dataset.py
ADDED
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Dict, List, Optional, Tuple
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as TT
|
11 |
+
import torchvision.transforms.functional as TTF
|
12 |
+
from accelerate.logging import get_logger
|
13 |
+
from torch.utils.data import Dataset, Sampler
|
14 |
+
from torchvision import transforms
|
15 |
+
from torchvision.transforms import InterpolationMode
|
16 |
+
from torchvision.transforms.functional import resize
|
17 |
+
|
18 |
+
|
19 |
+
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
|
20 |
+
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
21 |
+
import decord # isort:skip
|
22 |
+
|
23 |
+
decord.bridge.set_bridge("torch")
|
24 |
+
|
25 |
+
from .constants import ( # noqa
|
26 |
+
COMMON_LLM_START_PHRASES,
|
27 |
+
PRECOMPUTED_CONDITIONS_DIR_NAME,
|
28 |
+
PRECOMPUTED_DIR_NAME,
|
29 |
+
PRECOMPUTED_LATENTS_DIR_NAME,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
logger = get_logger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
# TODO(aryan): This needs a refactor with separation of concerns.
|
37 |
+
# Images should be handled separately. Videos should be handled separately.
|
38 |
+
# Loading should be handled separately.
|
39 |
+
# Preprocessing (aspect ratio, resizing) should be handled separately.
|
40 |
+
# URL loading should be handled.
|
41 |
+
# Parquet format should be handled.
|
42 |
+
# Loading from ZIP should be handled.
|
43 |
+
class ImageOrVideoDataset(Dataset):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
data_root: str,
|
47 |
+
caption_column: str,
|
48 |
+
video_column: str,
|
49 |
+
resolution_buckets: List[Tuple[int, int, int]],
|
50 |
+
dataset_file: Optional[str] = None,
|
51 |
+
id_token: Optional[str] = None,
|
52 |
+
remove_llm_prefixes: bool = False,
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
self.data_root = Path(data_root)
|
57 |
+
self.dataset_file = dataset_file
|
58 |
+
self.caption_column = caption_column
|
59 |
+
self.video_column = video_column
|
60 |
+
self.id_token = f"{id_token.strip()} " if id_token else ""
|
61 |
+
self.resolution_buckets = resolution_buckets
|
62 |
+
|
63 |
+
# Four methods of loading data are supported.
|
64 |
+
# - Using a CSV: caption_column and video_column must be some column in the CSV. One could
|
65 |
+
# make use of other columns too, such as a motion score or aesthetic score, by modifying the
|
66 |
+
# logic in CSV processing.
|
67 |
+
# - Using two files containing line-separate captions and relative paths to videos.
|
68 |
+
# - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and `video_column` key.
|
69 |
+
# - Using a JSONL file containing a list of line-separated dictionaries, where each dictionary has a `caption_column` and `video_column` key.
|
70 |
+
# For a more detailed explanation about preparing dataset format, checkout the README.
|
71 |
+
if dataset_file is None:
|
72 |
+
(
|
73 |
+
self.prompts,
|
74 |
+
self.video_paths,
|
75 |
+
) = self._load_dataset_from_local_path()
|
76 |
+
elif dataset_file.endswith(".csv"):
|
77 |
+
(
|
78 |
+
self.prompts,
|
79 |
+
self.video_paths,
|
80 |
+
) = self._load_dataset_from_csv()
|
81 |
+
elif dataset_file.endswith(".json"):
|
82 |
+
(
|
83 |
+
self.prompts,
|
84 |
+
self.video_paths,
|
85 |
+
) = self._load_dataset_from_json()
|
86 |
+
elif dataset_file.endswith(".jsonl"):
|
87 |
+
(
|
88 |
+
self.prompts,
|
89 |
+
self.video_paths,
|
90 |
+
) = self._load_dataset_from_jsonl()
|
91 |
+
else:
|
92 |
+
raise ValueError(
|
93 |
+
"Expected `--dataset_file` to be a path to a CSV file or a directory containing line-separated text prompts and video paths."
|
94 |
+
)
|
95 |
+
|
96 |
+
if len(self.video_paths) != len(self.prompts):
|
97 |
+
raise ValueError(
|
98 |
+
f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
|
99 |
+
)
|
100 |
+
|
101 |
+
# Clean LLM start phrases
|
102 |
+
if remove_llm_prefixes:
|
103 |
+
for i in range(len(self.prompts)):
|
104 |
+
self.prompts[i] = self.prompts[i].strip()
|
105 |
+
for phrase in COMMON_LLM_START_PHRASES:
|
106 |
+
if self.prompts[i].startswith(phrase):
|
107 |
+
self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
|
108 |
+
|
109 |
+
self.video_transforms = transforms.Compose(
|
110 |
+
[
|
111 |
+
transforms.Lambda(self.scale_transform),
|
112 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
113 |
+
]
|
114 |
+
)
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def scale_transform(x):
|
118 |
+
return x / 255.0
|
119 |
+
|
120 |
+
def __len__(self) -> int:
|
121 |
+
return len(self.video_paths)
|
122 |
+
|
123 |
+
def __getitem__(self, index: int) -> Dict[str, Any]:
|
124 |
+
if isinstance(index, list):
|
125 |
+
# Here, index is actually a list of data objects that we need to return.
|
126 |
+
# The BucketSampler should ideally return indices. But, in the sampler, we'd like
|
127 |
+
# to have information about num_frames, height and width. Since this is not stored
|
128 |
+
# as metadata, we need to read the video to get this information. You could read this
|
129 |
+
# information without loading the full video in memory, but we do it anyway. In order
|
130 |
+
# to not load the video twice (once to get the metadata, and once to return the loaded video
|
131 |
+
# based on sampled indices), we cache it in the BucketSampler. When the sampler is
|
132 |
+
# to yield, we yield the cache data instead of indices. So, this special check ensures
|
133 |
+
# that data is not loaded a second time. PRs are welcome for improvements.
|
134 |
+
return index
|
135 |
+
|
136 |
+
prompt = self.id_token + self.prompts[index]
|
137 |
+
|
138 |
+
video_path: Path = self.video_paths[index]
|
139 |
+
if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
|
140 |
+
video = self._preprocess_image(video_path)
|
141 |
+
else:
|
142 |
+
video = self._preprocess_video(video_path)
|
143 |
+
|
144 |
+
return {
|
145 |
+
"prompt": prompt,
|
146 |
+
"video": video,
|
147 |
+
"video_metadata": {
|
148 |
+
"num_frames": video.shape[0],
|
149 |
+
"height": video.shape[2],
|
150 |
+
"width": video.shape[3],
|
151 |
+
},
|
152 |
+
}
|
153 |
+
|
154 |
+
def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
|
155 |
+
if not self.data_root.exists():
|
156 |
+
raise ValueError("Root folder for videos does not exist")
|
157 |
+
|
158 |
+
prompt_path = self.data_root.joinpath(self.caption_column)
|
159 |
+
video_path = self.data_root.joinpath(self.video_column)
|
160 |
+
|
161 |
+
if not prompt_path.exists() or not prompt_path.is_file():
|
162 |
+
raise ValueError(
|
163 |
+
"Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
|
164 |
+
)
|
165 |
+
if not video_path.exists() or not video_path.is_file():
|
166 |
+
raise ValueError(
|
167 |
+
"Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
|
168 |
+
)
|
169 |
+
|
170 |
+
with open(prompt_path, "r", encoding="utf-8") as file:
|
171 |
+
prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
|
172 |
+
with open(video_path, "r", encoding="utf-8") as file:
|
173 |
+
video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
|
174 |
+
|
175 |
+
if any(not path.is_file() for path in video_paths):
|
176 |
+
raise ValueError(
|
177 |
+
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
|
178 |
+
)
|
179 |
+
|
180 |
+
return prompts, video_paths
|
181 |
+
|
182 |
+
def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
|
183 |
+
df = pd.read_csv(self.dataset_file)
|
184 |
+
prompts = df[self.caption_column].tolist()
|
185 |
+
video_paths = df[self.video_column].tolist()
|
186 |
+
video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
|
187 |
+
|
188 |
+
if any(not path.is_file() for path in video_paths):
|
189 |
+
raise ValueError(
|
190 |
+
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
|
191 |
+
)
|
192 |
+
|
193 |
+
return prompts, video_paths
|
194 |
+
|
195 |
+
def _load_dataset_from_json(self) -> Tuple[List[str], List[str]]:
|
196 |
+
with open(self.dataset_file, "r", encoding="utf-8") as file:
|
197 |
+
data = json.load(file)
|
198 |
+
|
199 |
+
prompts = [entry[self.caption_column] for entry in data]
|
200 |
+
video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
|
201 |
+
|
202 |
+
if any(not path.is_file() for path in video_paths):
|
203 |
+
raise ValueError(
|
204 |
+
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
|
205 |
+
)
|
206 |
+
|
207 |
+
return prompts, video_paths
|
208 |
+
|
209 |
+
def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[str]]:
|
210 |
+
with open(self.dataset_file, "r", encoding="utf-8") as file:
|
211 |
+
data = [json.loads(line) for line in file]
|
212 |
+
|
213 |
+
prompts = [entry[self.caption_column] for entry in data]
|
214 |
+
video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
|
215 |
+
|
216 |
+
if any(not path.is_file() for path in video_paths):
|
217 |
+
raise ValueError(
|
218 |
+
f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
|
219 |
+
)
|
220 |
+
|
221 |
+
return prompts, video_paths
|
222 |
+
|
223 |
+
def _preprocess_image(self, path: Path) -> torch.Tensor:
|
224 |
+
# TODO(aryan): Support alpha channel in future by whitening background
|
225 |
+
image = TTF.Image.open(path.as_posix()).convert("RGB")
|
226 |
+
image = TTF.to_tensor(image)
|
227 |
+
image = image * 2.0 - 1.0
|
228 |
+
image = image.unsqueeze(0).contiguous() # [C, H, W] -> [1, C, H, W] (1-frame video)
|
229 |
+
return image
|
230 |
+
|
231 |
+
def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
232 |
+
r"""
|
233 |
+
Loads a single video, or latent and prompt embedding, based on initialization parameters.
|
234 |
+
|
235 |
+
Returns a [F, C, H, W] video tensor.
|
236 |
+
"""
|
237 |
+
video_reader = decord.VideoReader(uri=path.as_posix())
|
238 |
+
video_num_frames = len(video_reader)
|
239 |
+
|
240 |
+
indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
|
241 |
+
frames = video_reader.get_batch(indices)
|
242 |
+
frames = frames[: self.max_num_frames].float()
|
243 |
+
frames = frames.permute(0, 3, 1, 2).contiguous()
|
244 |
+
frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
|
245 |
+
return frames
|
246 |
+
|
247 |
+
|
248 |
+
class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
|
249 |
+
def __init__(self, *args, **kwargs) -> None:
|
250 |
+
super().__init__(*args, **kwargs)
|
251 |
+
|
252 |
+
self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
|
253 |
+
|
254 |
+
def _preprocess_image(self, path: Path) -> torch.Tensor:
|
255 |
+
# TODO(aryan): Support alpha channel in future by whitening background
|
256 |
+
image = TTF.Image.open(path.as_posix()).convert("RGB")
|
257 |
+
image = TTF.to_tensor(image)
|
258 |
+
|
259 |
+
nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2])
|
260 |
+
image = resize(image, nearest_res)
|
261 |
+
|
262 |
+
image = image * 2.0 - 1.0
|
263 |
+
image = image.unsqueeze(0).contiguous()
|
264 |
+
return image
|
265 |
+
|
266 |
+
def _preprocess_video(self, path: Path) -> torch.Tensor:
|
267 |
+
video_reader = decord.VideoReader(uri=path.as_posix())
|
268 |
+
video_num_frames = len(video_reader)
|
269 |
+
print(f"ImageOrVideoDatasetWithResizing: self.resolution_buckets = ", self.resolution_buckets)
|
270 |
+
print(f"ImageOrVideoDatasetWithResizing: self.max_num_frames = ", self.max_num_frames)
|
271 |
+
print(f"ImageOrVideoDatasetWithResizing: video_num_frames = ", video_num_frames)
|
272 |
+
|
273 |
+
video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
|
274 |
+
|
275 |
+
if not video_buckets:
|
276 |
+
_, h, w = self.resolution_buckets[0]
|
277 |
+
video_buckets = [(1, h, w)]
|
278 |
+
|
279 |
+
nearest_frame_bucket = min(
|
280 |
+
video_buckets,
|
281 |
+
key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
|
282 |
+
default=video_buckets[0],
|
283 |
+
)[0]
|
284 |
+
|
285 |
+
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
|
286 |
+
|
287 |
+
frames = video_reader.get_batch(frame_indices)
|
288 |
+
frames = frames[:nearest_frame_bucket].float()
|
289 |
+
frames = frames.permute(0, 3, 1, 2).contiguous()
|
290 |
+
|
291 |
+
nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
|
292 |
+
frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
|
293 |
+
frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
|
294 |
+
|
295 |
+
return frames
|
296 |
+
|
297 |
+
def _find_nearest_resolution(self, height, width):
|
298 |
+
nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
|
299 |
+
return nearest_res[1], nearest_res[2]
|
300 |
+
|
301 |
+
|
302 |
+
class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
|
303 |
+
def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
|
304 |
+
super().__init__(*args, **kwargs)
|
305 |
+
|
306 |
+
self.video_reshape_mode = video_reshape_mode
|
307 |
+
self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
|
308 |
+
|
309 |
+
def _resize_for_rectangle_crop(self, arr, image_size):
|
310 |
+
reshape_mode = self.video_reshape_mode
|
311 |
+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
312 |
+
arr = resize(
|
313 |
+
arr,
|
314 |
+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
315 |
+
interpolation=InterpolationMode.BICUBIC,
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
arr = resize(
|
319 |
+
arr,
|
320 |
+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
321 |
+
interpolation=InterpolationMode.BICUBIC,
|
322 |
+
)
|
323 |
+
|
324 |
+
h, w = arr.shape[2], arr.shape[3]
|
325 |
+
arr = arr.squeeze(0)
|
326 |
+
|
327 |
+
delta_h = h - image_size[0]
|
328 |
+
delta_w = w - image_size[1]
|
329 |
+
|
330 |
+
if reshape_mode == "random" or reshape_mode == "none":
|
331 |
+
top = np.random.randint(0, delta_h + 1)
|
332 |
+
left = np.random.randint(0, delta_w + 1)
|
333 |
+
elif reshape_mode == "center":
|
334 |
+
top, left = delta_h // 2, delta_w // 2
|
335 |
+
else:
|
336 |
+
raise NotImplementedError
|
337 |
+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
338 |
+
return arr
|
339 |
+
|
340 |
+
def _preprocess_video(self, path: Path) -> torch.Tensor:
|
341 |
+
video_reader = decord.VideoReader(uri=path.as_posix())
|
342 |
+
video_num_frames = len(video_reader)
|
343 |
+
print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.resolution_buckets = ", self.resolution_buckets)
|
344 |
+
print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: self.max_num_frames = ", self.max_num_frames)
|
345 |
+
print(f"ImageOrVideoDatasetWithResizeAndRectangleCrop: video_num_frames = ", video_num_frames)
|
346 |
+
|
347 |
+
video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
|
348 |
+
|
349 |
+
if not video_buckets:
|
350 |
+
_, h, w = self.resolution_buckets[0]
|
351 |
+
video_buckets = [(1, h, w)]
|
352 |
+
|
353 |
+
nearest_frame_bucket = min(
|
354 |
+
video_buckets,
|
355 |
+
key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
|
356 |
+
default=video_buckets[0],
|
357 |
+
)[0]
|
358 |
+
|
359 |
+
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
|
360 |
+
|
361 |
+
frames = video_reader.get_batch(frame_indices)
|
362 |
+
frames = frames[:nearest_frame_bucket].float()
|
363 |
+
frames = frames.permute(0, 3, 1, 2).contiguous()
|
364 |
+
|
365 |
+
nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
|
366 |
+
frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
|
367 |
+
frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
|
368 |
+
return frames
|
369 |
+
|
370 |
+
def _find_nearest_resolution(self, height, width):
|
371 |
+
nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
|
372 |
+
return nearest_res[1], nearest_res[2]
|
373 |
+
|
374 |
+
|
375 |
+
class PrecomputedDataset(Dataset):
|
376 |
+
def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
|
377 |
+
super().__init__()
|
378 |
+
|
379 |
+
self.data_root = Path(data_root)
|
380 |
+
|
381 |
+
if model_name and cleaned_model_id:
|
382 |
+
precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
|
383 |
+
self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
|
384 |
+
self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
|
385 |
+
else:
|
386 |
+
self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
|
387 |
+
self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
|
388 |
+
|
389 |
+
self.latent_conditions = sorted(os.listdir(self.latents_path))
|
390 |
+
self.text_conditions = sorted(os.listdir(self.conditions_path))
|
391 |
+
|
392 |
+
assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match"
|
393 |
+
|
394 |
+
def __len__(self) -> int:
|
395 |
+
return len(self.latent_conditions)
|
396 |
+
|
397 |
+
def __getitem__(self, index: int) -> Dict[str, Any]:
|
398 |
+
conditions = {}
|
399 |
+
latent_path = self.latents_path / self.latent_conditions[index]
|
400 |
+
condition_path = self.conditions_path / self.text_conditions[index]
|
401 |
+
conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
|
402 |
+
conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
|
403 |
+
return conditions
|
404 |
+
|
405 |
+
|
406 |
+
class BucketSampler(Sampler):
|
407 |
+
r"""
|
408 |
+
PyTorch Sampler that groups 3D data by height, width and frames.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
data_source (`ImageOrVideoDataset`):
|
412 |
+
A PyTorch dataset object that is an instance of `ImageOrVideoDataset`.
|
413 |
+
batch_size (`int`, defaults to `8`):
|
414 |
+
The batch size to use for training.
|
415 |
+
shuffle (`bool`, defaults to `True`):
|
416 |
+
Whether or not to shuffle the data in each batch before dispatching to dataloader.
|
417 |
+
drop_last (`bool`, defaults to `False`):
|
418 |
+
Whether or not to drop incomplete buckets of data after completely iterating over all data
|
419 |
+
in the dataset. If set to True, only batches that have `batch_size` number of entries will
|
420 |
+
be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
|
421 |
+
and batches that do not have `batch_size` number of entries will also be yielded.
|
422 |
+
"""
|
423 |
+
|
424 |
+
def __init__(
|
425 |
+
self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
|
426 |
+
) -> None:
|
427 |
+
self.data_source = data_source
|
428 |
+
self.batch_size = batch_size
|
429 |
+
self.shuffle = shuffle
|
430 |
+
self.drop_last = drop_last
|
431 |
+
|
432 |
+
self.buckets = {resolution: [] for resolution in data_source.resolution_buckets}
|
433 |
+
|
434 |
+
self._raised_warning_for_drop_last = False
|
435 |
+
|
436 |
+
def __len__(self):
|
437 |
+
if self.drop_last and not self._raised_warning_for_drop_last:
|
438 |
+
self._raised_warning_for_drop_last = True
|
439 |
+
logger.warning(
|
440 |
+
"Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
|
441 |
+
)
|
442 |
+
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
|
443 |
+
|
444 |
+
def __iter__(self):
|
445 |
+
for index, data in enumerate(self.data_source):
|
446 |
+
video_metadata = data["video_metadata"]
|
447 |
+
f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
|
448 |
+
|
449 |
+
self.buckets[(f, h, w)].append(data)
|
450 |
+
if len(self.buckets[(f, h, w)]) == self.batch_size:
|
451 |
+
if self.shuffle:
|
452 |
+
random.shuffle(self.buckets[(f, h, w)])
|
453 |
+
yield self.buckets[(f, h, w)]
|
454 |
+
del self.buckets[(f, h, w)]
|
455 |
+
self.buckets[(f, h, w)] = []
|
456 |
+
|
457 |
+
if self.drop_last:
|
458 |
+
return
|
459 |
+
|
460 |
+
for fhw, bucket in list(self.buckets.items()):
|
461 |
+
if len(bucket) == 0:
|
462 |
+
continue
|
463 |
+
if self.shuffle:
|
464 |
+
random.shuffle(bucket)
|
465 |
+
yield bucket
|
466 |
+
del self.buckets[fhw]
|
467 |
+
self.buckets[fhw] = []
|
finetrainers/hooks/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .layerwise_upcasting import apply_layerwise_upcasting
|
finetrainers/hooks/hooks.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import functools
|
16 |
+
from typing import Any, Dict, Optional, Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
|
21 |
+
from ..constants import FINETRAINERS_LOG_LEVEL
|
22 |
+
|
23 |
+
|
24 |
+
logger = get_logger("finetrainers") # pylint: disable=invalid-name
|
25 |
+
logger.setLevel(FINETRAINERS_LOG_LEVEL)
|
26 |
+
|
27 |
+
|
28 |
+
class ModelHook:
|
29 |
+
r"""
|
30 |
+
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
31 |
+
"""
|
32 |
+
|
33 |
+
_is_stateful = False
|
34 |
+
|
35 |
+
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
36 |
+
r"""
|
37 |
+
Hook that is executed when a model is initialized.
|
38 |
+
Args:
|
39 |
+
module (`torch.nn.Module`):
|
40 |
+
The module attached to this hook.
|
41 |
+
"""
|
42 |
+
return module
|
43 |
+
|
44 |
+
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
45 |
+
r"""
|
46 |
+
Hook that is executed when a model is deinitalized.
|
47 |
+
Args:
|
48 |
+
module (`torch.nn.Module`):
|
49 |
+
The module attached to this hook.
|
50 |
+
"""
|
51 |
+
module.forward = module._old_forward
|
52 |
+
del module._old_forward
|
53 |
+
return module
|
54 |
+
|
55 |
+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
56 |
+
r"""
|
57 |
+
Hook that is executed just before the forward method of the model.
|
58 |
+
Args:
|
59 |
+
module (`torch.nn.Module`):
|
60 |
+
The module whose forward pass will be executed just after this event.
|
61 |
+
args (`Tuple[Any]`):
|
62 |
+
The positional arguments passed to the module.
|
63 |
+
kwargs (`Dict[Str, Any]`):
|
64 |
+
The keyword arguments passed to the module.
|
65 |
+
Returns:
|
66 |
+
`Tuple[Tuple[Any], Dict[Str, Any]]`:
|
67 |
+
A tuple with the treated `args` and `kwargs`.
|
68 |
+
"""
|
69 |
+
return args, kwargs
|
70 |
+
|
71 |
+
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
|
72 |
+
r"""
|
73 |
+
Hook that is executed just after the forward method of the model.
|
74 |
+
Args:
|
75 |
+
module (`torch.nn.Module`):
|
76 |
+
The module whose forward pass been executed just before this event.
|
77 |
+
output (`Any`):
|
78 |
+
The output of the module.
|
79 |
+
Returns:
|
80 |
+
`Any`: The processed `output`.
|
81 |
+
"""
|
82 |
+
return output
|
83 |
+
|
84 |
+
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
85 |
+
r"""
|
86 |
+
Hook that is executed when the hook is detached from a module.
|
87 |
+
Args:
|
88 |
+
module (`torch.nn.Module`):
|
89 |
+
The module detached from this hook.
|
90 |
+
"""
|
91 |
+
return module
|
92 |
+
|
93 |
+
def reset_state(self, module: torch.nn.Module):
|
94 |
+
if self._is_stateful:
|
95 |
+
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
96 |
+
return module
|
97 |
+
|
98 |
+
|
99 |
+
class HookRegistry:
|
100 |
+
def __init__(self, module_ref: torch.nn.Module) -> None:
|
101 |
+
super().__init__()
|
102 |
+
|
103 |
+
self.hooks: Dict[str, ModelHook] = {}
|
104 |
+
|
105 |
+
self._module_ref = module_ref
|
106 |
+
self._hook_order = []
|
107 |
+
|
108 |
+
def register_hook(self, hook: ModelHook, name: str) -> None:
|
109 |
+
if name in self.hooks.keys():
|
110 |
+
logger.warning(f"Hook with name {name} already exists, replacing it.")
|
111 |
+
|
112 |
+
if hasattr(self._module_ref, "_old_forward"):
|
113 |
+
old_forward = self._module_ref._old_forward
|
114 |
+
else:
|
115 |
+
old_forward = self._module_ref.forward
|
116 |
+
self._module_ref._old_forward = self._module_ref.forward
|
117 |
+
|
118 |
+
self._module_ref = hook.initialize_hook(self._module_ref)
|
119 |
+
|
120 |
+
if hasattr(hook, "new_forward"):
|
121 |
+
rewritten_forward = hook.new_forward
|
122 |
+
|
123 |
+
def new_forward(module, *args, **kwargs):
|
124 |
+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
125 |
+
output = rewritten_forward(module, *args, **kwargs)
|
126 |
+
return hook.post_forward(module, output)
|
127 |
+
else:
|
128 |
+
|
129 |
+
def new_forward(module, *args, **kwargs):
|
130 |
+
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
131 |
+
output = old_forward(*args, **kwargs)
|
132 |
+
return hook.post_forward(module, output)
|
133 |
+
|
134 |
+
self._module_ref.forward = functools.update_wrapper(
|
135 |
+
functools.partial(new_forward, self._module_ref), old_forward
|
136 |
+
)
|
137 |
+
|
138 |
+
self.hooks[name] = hook
|
139 |
+
self._hook_order.append(name)
|
140 |
+
|
141 |
+
def get_hook(self, name: str) -> Optional[ModelHook]:
|
142 |
+
if name not in self.hooks.keys():
|
143 |
+
return None
|
144 |
+
return self.hooks[name]
|
145 |
+
|
146 |
+
def remove_hook(self, name: str) -> None:
|
147 |
+
if name not in self.hooks.keys():
|
148 |
+
raise ValueError(f"Hook with name {name} not found.")
|
149 |
+
self.hooks[name].deinitalize_hook(self._module_ref)
|
150 |
+
del self.hooks[name]
|
151 |
+
self._hook_order.remove(name)
|
152 |
+
|
153 |
+
def reset_stateful_hooks(self, recurse: bool = True) -> None:
|
154 |
+
for hook_name in self._hook_order:
|
155 |
+
hook = self.hooks[hook_name]
|
156 |
+
if hook._is_stateful:
|
157 |
+
hook.reset_state(self._module_ref)
|
158 |
+
|
159 |
+
if recurse:
|
160 |
+
for module in self._module_ref.modules():
|
161 |
+
if hasattr(module, "_diffusers_hook"):
|
162 |
+
module._diffusers_hook.reset_stateful_hooks(recurse=False)
|
163 |
+
|
164 |
+
@classmethod
|
165 |
+
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
|
166 |
+
if not hasattr(module, "_diffusers_hook"):
|
167 |
+
module._diffusers_hook = cls(module)
|
168 |
+
return module._diffusers_hook
|
169 |
+
|
170 |
+
def __repr__(self) -> str:
|
171 |
+
hook_repr = ""
|
172 |
+
for i, hook_name in enumerate(self._hook_order):
|
173 |
+
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
|
174 |
+
if i < len(self._hook_order) - 1:
|
175 |
+
hook_repr += "\n"
|
176 |
+
return f"HookRegistry(\n{hook_repr}\n)"
|
finetrainers/hooks/layerwise_upcasting.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
from typing import Optional, Tuple, Type
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from accelerate.logging import get_logger
|
20 |
+
|
21 |
+
from ..constants import FINETRAINERS_LOG_LEVEL
|
22 |
+
from .hooks import HookRegistry, ModelHook
|
23 |
+
|
24 |
+
|
25 |
+
logger = get_logger("finetrainers") # pylint: disable=invalid-name
|
26 |
+
logger.setLevel(FINETRAINERS_LOG_LEVEL)
|
27 |
+
|
28 |
+
|
29 |
+
# fmt: off
|
30 |
+
_SUPPORTED_PYTORCH_LAYERS = (
|
31 |
+
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
|
32 |
+
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
|
33 |
+
torch.nn.Linear,
|
34 |
+
)
|
35 |
+
|
36 |
+
_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
|
37 |
+
# fmt: on
|
38 |
+
|
39 |
+
|
40 |
+
class LayerwiseUpcastingHook(ModelHook):
|
41 |
+
r"""
|
42 |
+
A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
|
43 |
+
for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
|
44 |
+
footprint.
|
45 |
+
"""
|
46 |
+
|
47 |
+
_is_stateful = False
|
48 |
+
|
49 |
+
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
|
50 |
+
self.storage_dtype = storage_dtype
|
51 |
+
self.compute_dtype = compute_dtype
|
52 |
+
self.non_blocking = non_blocking
|
53 |
+
|
54 |
+
def initialize_hook(self, module: torch.nn.Module):
|
55 |
+
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
|
56 |
+
return module
|
57 |
+
|
58 |
+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
59 |
+
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
|
60 |
+
return args, kwargs
|
61 |
+
|
62 |
+
def post_forward(self, module: torch.nn.Module, output):
|
63 |
+
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
|
64 |
+
return output
|
65 |
+
|
66 |
+
|
67 |
+
def apply_layerwise_upcasting(
|
68 |
+
module: torch.nn.Module,
|
69 |
+
storage_dtype: torch.dtype,
|
70 |
+
compute_dtype: torch.dtype,
|
71 |
+
skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
|
72 |
+
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
|
73 |
+
non_blocking: bool = False,
|
74 |
+
_prefix: str = "",
|
75 |
+
) -> None:
|
76 |
+
r"""
|
77 |
+
Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
|
78 |
+
nn.Module using diffusers layers or pytorch primitives.
|
79 |
+
Args:
|
80 |
+
module (`torch.nn.Module`):
|
81 |
+
The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
|
82 |
+
precision dtype for storage.
|
83 |
+
storage_dtype (`torch.dtype`):
|
84 |
+
The dtype to cast the module to before/after the forward pass for storage.
|
85 |
+
compute_dtype (`torch.dtype`):
|
86 |
+
The dtype to cast the module to during the forward pass for computation.
|
87 |
+
skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
|
88 |
+
A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
|
89 |
+
skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
|
90 |
+
A list of module classes to skip during the layerwise upcasting process.
|
91 |
+
non_blocking (`bool`, defaults to `False`):
|
92 |
+
If `True`, the weight casting operations are non-blocking.
|
93 |
+
"""
|
94 |
+
if skip_modules_classes is None and skip_modules_pattern is None:
|
95 |
+
apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
|
96 |
+
return
|
97 |
+
|
98 |
+
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
|
99 |
+
skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
|
100 |
+
)
|
101 |
+
if should_skip:
|
102 |
+
logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
|
103 |
+
return
|
104 |
+
|
105 |
+
if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
|
106 |
+
logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
|
107 |
+
apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
|
108 |
+
return
|
109 |
+
|
110 |
+
for name, submodule in module.named_children():
|
111 |
+
layer_name = f"{_prefix}.{name}" if _prefix else name
|
112 |
+
apply_layerwise_upcasting(
|
113 |
+
submodule,
|
114 |
+
storage_dtype,
|
115 |
+
compute_dtype,
|
116 |
+
skip_modules_pattern,
|
117 |
+
skip_modules_classes,
|
118 |
+
non_blocking,
|
119 |
+
_prefix=layer_name,
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
def apply_layerwise_upcasting_hook(
|
124 |
+
module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
|
125 |
+
) -> None:
|
126 |
+
r"""
|
127 |
+
Applies a `LayerwiseUpcastingHook` to a given module.
|
128 |
+
Args:
|
129 |
+
module (`torch.nn.Module`):
|
130 |
+
The module to attach the hook to.
|
131 |
+
storage_dtype (`torch.dtype`):
|
132 |
+
The dtype to cast the module to before the forward pass.
|
133 |
+
compute_dtype (`torch.dtype`):
|
134 |
+
The dtype to cast the module to during the forward pass.
|
135 |
+
non_blocking (`bool`):
|
136 |
+
If `True`, the weight casting operations are non-blocking.
|
137 |
+
"""
|
138 |
+
registry = HookRegistry.check_if_exists_or_initialize(module)
|
139 |
+
hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
|
140 |
+
registry.register_hook(hook, "layerwise_upcasting")
|
finetrainers/models/__init__.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict
|
2 |
+
|
3 |
+
from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
|
4 |
+
from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG
|
5 |
+
from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG
|
6 |
+
|
7 |
+
|
8 |
+
SUPPORTED_MODEL_CONFIGS = {
|
9 |
+
"hunyuan_video": {
|
10 |
+
"lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
|
11 |
+
"full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
|
12 |
+
},
|
13 |
+
"ltx_video": {
|
14 |
+
"lora": LTX_VIDEO_T2V_LORA_CONFIG,
|
15 |
+
"full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
|
16 |
+
},
|
17 |
+
"cogvideox": {
|
18 |
+
"lora": COGVIDEOX_T2V_LORA_CONFIG,
|
19 |
+
"full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
|
20 |
+
},
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
|
25 |
+
if model_name not in SUPPORTED_MODEL_CONFIGS:
|
26 |
+
raise ValueError(
|
27 |
+
f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
|
28 |
+
)
|
29 |
+
if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
|
30 |
+
raise ValueError(
|
31 |
+
f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
|
32 |
+
)
|
33 |
+
return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
|
finetrainers/models/cogvideox/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG
|
2 |
+
from .lora import COGVIDEOX_T2V_LORA_CONFIG
|
finetrainers/models/cogvideox/full_finetune.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import CogVideoXPipeline
|
2 |
+
|
3 |
+
from .lora import (
|
4 |
+
calculate_noisy_latents,
|
5 |
+
collate_fn_t2v,
|
6 |
+
forward_pass,
|
7 |
+
initialize_pipeline,
|
8 |
+
load_condition_models,
|
9 |
+
load_diffusion_models,
|
10 |
+
load_latent_models,
|
11 |
+
post_latent_preparation,
|
12 |
+
prepare_conditions,
|
13 |
+
prepare_latents,
|
14 |
+
validation,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
# TODO(aryan): refactor into model specs for better re-use
|
19 |
+
COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = {
|
20 |
+
"pipeline_cls": CogVideoXPipeline,
|
21 |
+
"load_condition_models": load_condition_models,
|
22 |
+
"load_latent_models": load_latent_models,
|
23 |
+
"load_diffusion_models": load_diffusion_models,
|
24 |
+
"initialize_pipeline": initialize_pipeline,
|
25 |
+
"prepare_conditions": prepare_conditions,
|
26 |
+
"prepare_latents": prepare_latents,
|
27 |
+
"post_latent_preparation": post_latent_preparation,
|
28 |
+
"collate_fn": collate_fn_t2v,
|
29 |
+
"calculate_noisy_latents": calculate_noisy_latents,
|
30 |
+
"forward_pass": forward_pass,
|
31 |
+
"validation": validation,
|
32 |
+
}
|
finetrainers/models/cogvideox/lora.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
7 |
+
|
8 |
+
from .utils import prepare_rotary_positional_embeddings
|
9 |
+
|
10 |
+
|
11 |
+
def load_condition_models(
|
12 |
+
model_id: str = "THUDM/CogVideoX-5b",
|
13 |
+
text_encoder_dtype: torch.dtype = torch.bfloat16,
|
14 |
+
revision: Optional[str] = None,
|
15 |
+
cache_dir: Optional[str] = None,
|
16 |
+
**kwargs,
|
17 |
+
):
|
18 |
+
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
|
19 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
20 |
+
model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
|
21 |
+
)
|
22 |
+
return {"tokenizer": tokenizer, "text_encoder": text_encoder}
|
23 |
+
|
24 |
+
|
25 |
+
def load_latent_models(
|
26 |
+
model_id: str = "THUDM/CogVideoX-5b",
|
27 |
+
vae_dtype: torch.dtype = torch.bfloat16,
|
28 |
+
revision: Optional[str] = None,
|
29 |
+
cache_dir: Optional[str] = None,
|
30 |
+
**kwargs,
|
31 |
+
):
|
32 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
33 |
+
model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
|
34 |
+
)
|
35 |
+
return {"vae": vae}
|
36 |
+
|
37 |
+
|
38 |
+
def load_diffusion_models(
|
39 |
+
model_id: str = "THUDM/CogVideoX-5b",
|
40 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
41 |
+
revision: Optional[str] = None,
|
42 |
+
cache_dir: Optional[str] = None,
|
43 |
+
**kwargs,
|
44 |
+
):
|
45 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(
|
46 |
+
model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
|
47 |
+
)
|
48 |
+
scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
49 |
+
return {"transformer": transformer, "scheduler": scheduler}
|
50 |
+
|
51 |
+
|
52 |
+
def initialize_pipeline(
|
53 |
+
model_id: str = "THUDM/CogVideoX-5b",
|
54 |
+
text_encoder_dtype: torch.dtype = torch.bfloat16,
|
55 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
56 |
+
vae_dtype: torch.dtype = torch.bfloat16,
|
57 |
+
tokenizer: Optional[T5Tokenizer] = None,
|
58 |
+
text_encoder: Optional[T5EncoderModel] = None,
|
59 |
+
transformer: Optional[CogVideoXTransformer3DModel] = None,
|
60 |
+
vae: Optional[AutoencoderKLCogVideoX] = None,
|
61 |
+
scheduler: Optional[CogVideoXDDIMScheduler] = None,
|
62 |
+
device: Optional[torch.device] = None,
|
63 |
+
revision: Optional[str] = None,
|
64 |
+
cache_dir: Optional[str] = None,
|
65 |
+
enable_slicing: bool = False,
|
66 |
+
enable_tiling: bool = False,
|
67 |
+
enable_model_cpu_offload: bool = False,
|
68 |
+
is_training: bool = False,
|
69 |
+
**kwargs,
|
70 |
+
) -> CogVideoXPipeline:
|
71 |
+
component_name_pairs = [
|
72 |
+
("tokenizer", tokenizer),
|
73 |
+
("text_encoder", text_encoder),
|
74 |
+
("transformer", transformer),
|
75 |
+
("vae", vae),
|
76 |
+
("scheduler", scheduler),
|
77 |
+
]
|
78 |
+
components = {}
|
79 |
+
for name, component in component_name_pairs:
|
80 |
+
if component is not None:
|
81 |
+
components[name] = component
|
82 |
+
|
83 |
+
pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
|
84 |
+
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
|
85 |
+
pipe.vae = pipe.vae.to(dtype=vae_dtype)
|
86 |
+
|
87 |
+
# The transformer should already be in the correct dtype when training, so we don't need to cast it here.
|
88 |
+
# If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
|
89 |
+
# DDP optimizer step.
|
90 |
+
if not is_training:
|
91 |
+
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
|
92 |
+
|
93 |
+
if enable_slicing:
|
94 |
+
pipe.vae.enable_slicing()
|
95 |
+
if enable_tiling:
|
96 |
+
pipe.vae.enable_tiling()
|
97 |
+
|
98 |
+
if enable_model_cpu_offload:
|
99 |
+
pipe.enable_model_cpu_offload(device=device)
|
100 |
+
else:
|
101 |
+
pipe.to(device=device)
|
102 |
+
|
103 |
+
return pipe
|
104 |
+
|
105 |
+
|
106 |
+
def prepare_conditions(
|
107 |
+
tokenizer,
|
108 |
+
text_encoder,
|
109 |
+
prompt: Union[str, List[str]],
|
110 |
+
device: Optional[torch.device] = None,
|
111 |
+
dtype: Optional[torch.dtype] = None,
|
112 |
+
max_sequence_length: int = 226, # TODO: this should be configurable
|
113 |
+
**kwargs,
|
114 |
+
):
|
115 |
+
device = device or text_encoder.device
|
116 |
+
dtype = dtype or text_encoder.dtype
|
117 |
+
return _get_t5_prompt_embeds(
|
118 |
+
tokenizer=tokenizer,
|
119 |
+
text_encoder=text_encoder,
|
120 |
+
prompt=prompt,
|
121 |
+
max_sequence_length=max_sequence_length,
|
122 |
+
device=device,
|
123 |
+
dtype=dtype,
|
124 |
+
)
|
125 |
+
|
126 |
+
|
127 |
+
def prepare_latents(
|
128 |
+
vae: AutoencoderKLCogVideoX,
|
129 |
+
image_or_video: torch.Tensor,
|
130 |
+
device: Optional[torch.device] = None,
|
131 |
+
dtype: Optional[torch.dtype] = None,
|
132 |
+
generator: Optional[torch.Generator] = None,
|
133 |
+
precompute: bool = False,
|
134 |
+
**kwargs,
|
135 |
+
) -> torch.Tensor:
|
136 |
+
device = device or vae.device
|
137 |
+
dtype = dtype or vae.dtype
|
138 |
+
|
139 |
+
if image_or_video.ndim == 4:
|
140 |
+
image_or_video = image_or_video.unsqueeze(2)
|
141 |
+
assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
|
142 |
+
|
143 |
+
image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
|
144 |
+
image_or_video = image_or_video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
145 |
+
if not precompute:
|
146 |
+
latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
|
147 |
+
if not vae.config.invert_scale_latents:
|
148 |
+
latents = latents * vae.config.scaling_factor
|
149 |
+
# For training Cog 1.5, we don't need to handle the scaling factor here.
|
150 |
+
# The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
|
151 |
+
# is probably only needed for image-to-video training.
|
152 |
+
# TODO(aryan): investigate this
|
153 |
+
# else:
|
154 |
+
# latents = 1 / vae.config.scaling_factor * latents
|
155 |
+
latents = latents.to(dtype=dtype)
|
156 |
+
return {"latents": latents}
|
157 |
+
else:
|
158 |
+
# handle vae scaling in the `train()` method directly.
|
159 |
+
if vae.use_slicing and image_or_video.shape[0] > 1:
|
160 |
+
encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
|
161 |
+
h = torch.cat(encoded_slices)
|
162 |
+
else:
|
163 |
+
h = vae._encode(image_or_video)
|
164 |
+
return {"latents": h}
|
165 |
+
|
166 |
+
|
167 |
+
def post_latent_preparation(
|
168 |
+
vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs
|
169 |
+
) -> torch.Tensor:
|
170 |
+
if not vae_config.invert_scale_latents:
|
171 |
+
latents = latents * vae_config.scaling_factor
|
172 |
+
# For training Cog 1.5, we don't need to handle the scaling factor here.
|
173 |
+
# The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
|
174 |
+
# is probably only needed for image-to-video training.
|
175 |
+
# TODO(aryan): investigate this
|
176 |
+
# else:
|
177 |
+
# latents = 1 / vae_config.scaling_factor * latents
|
178 |
+
latents = _pad_frames(latents, patch_size_t)
|
179 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
180 |
+
return {"latents": latents}
|
181 |
+
|
182 |
+
|
183 |
+
def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
184 |
+
return {
|
185 |
+
"prompts": [x["prompt"] for x in batch[0]],
|
186 |
+
"videos": torch.stack([x["video"] for x in batch[0]]),
|
187 |
+
}
|
188 |
+
|
189 |
+
|
190 |
+
def calculate_noisy_latents(
|
191 |
+
scheduler: CogVideoXDDIMScheduler,
|
192 |
+
noise: torch.Tensor,
|
193 |
+
latents: torch.Tensor,
|
194 |
+
timesteps: torch.LongTensor,
|
195 |
+
) -> torch.Tensor:
|
196 |
+
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
|
197 |
+
return noisy_latents
|
198 |
+
|
199 |
+
|
200 |
+
def forward_pass(
|
201 |
+
transformer: CogVideoXTransformer3DModel,
|
202 |
+
scheduler: CogVideoXDDIMScheduler,
|
203 |
+
prompt_embeds: torch.Tensor,
|
204 |
+
latents: torch.Tensor,
|
205 |
+
noisy_latents: torch.Tensor,
|
206 |
+
timesteps: torch.LongTensor,
|
207 |
+
ofs_emb: Optional[torch.Tensor] = None,
|
208 |
+
**kwargs,
|
209 |
+
) -> torch.Tensor:
|
210 |
+
# Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
|
211 |
+
VAE_SPATIAL_SCALE_FACTOR = 8
|
212 |
+
transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
|
213 |
+
batch_size, num_frames, num_channels, height, width = noisy_latents.shape
|
214 |
+
rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
|
215 |
+
rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
|
216 |
+
|
217 |
+
image_rotary_emb = (
|
218 |
+
prepare_rotary_positional_embeddings(
|
219 |
+
height=height * VAE_SPATIAL_SCALE_FACTOR,
|
220 |
+
width=width * VAE_SPATIAL_SCALE_FACTOR,
|
221 |
+
num_frames=num_frames,
|
222 |
+
vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
|
223 |
+
patch_size=transformer_config.patch_size,
|
224 |
+
patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None,
|
225 |
+
attention_head_dim=transformer_config.attention_head_dim,
|
226 |
+
device=transformer.device,
|
227 |
+
base_height=rope_base_height,
|
228 |
+
base_width=rope_base_width,
|
229 |
+
)
|
230 |
+
if transformer_config.use_rotary_positional_embeddings
|
231 |
+
else None
|
232 |
+
)
|
233 |
+
ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0)
|
234 |
+
|
235 |
+
velocity = transformer(
|
236 |
+
hidden_states=noisy_latents,
|
237 |
+
timestep=timesteps,
|
238 |
+
encoder_hidden_states=prompt_embeds,
|
239 |
+
ofs=ofs_emb,
|
240 |
+
image_rotary_emb=image_rotary_emb,
|
241 |
+
return_dict=False,
|
242 |
+
)[0]
|
243 |
+
# For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
|
244 |
+
# code paths as scheduler.get_velocity(), which can be confusing to understand.
|
245 |
+
denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)
|
246 |
+
|
247 |
+
return {"latents": denoised_latents}
|
248 |
+
|
249 |
+
|
250 |
+
def validation(
|
251 |
+
pipeline: CogVideoXPipeline,
|
252 |
+
prompt: str,
|
253 |
+
image: Optional[Image.Image] = None,
|
254 |
+
video: Optional[List[Image.Image]] = None,
|
255 |
+
height: Optional[int] = None,
|
256 |
+
width: Optional[int] = None,
|
257 |
+
num_frames: Optional[int] = None,
|
258 |
+
num_videos_per_prompt: int = 1,
|
259 |
+
generator: Optional[torch.Generator] = None,
|
260 |
+
**kwargs,
|
261 |
+
):
|
262 |
+
generation_kwargs = {
|
263 |
+
"prompt": prompt,
|
264 |
+
"height": height,
|
265 |
+
"width": width,
|
266 |
+
"num_frames": num_frames,
|
267 |
+
"num_videos_per_prompt": num_videos_per_prompt,
|
268 |
+
"generator": generator,
|
269 |
+
"return_dict": True,
|
270 |
+
"output_type": "pil",
|
271 |
+
}
|
272 |
+
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
|
273 |
+
output = pipeline(**generation_kwargs).frames[0]
|
274 |
+
return [("video", output)]
|
275 |
+
|
276 |
+
|
277 |
+
def _get_t5_prompt_embeds(
|
278 |
+
tokenizer: T5Tokenizer,
|
279 |
+
text_encoder: T5EncoderModel,
|
280 |
+
prompt: Union[str, List[str]] = None,
|
281 |
+
max_sequence_length: int = 226,
|
282 |
+
device: Optional[torch.device] = None,
|
283 |
+
dtype: Optional[torch.dtype] = None,
|
284 |
+
):
|
285 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
286 |
+
|
287 |
+
text_inputs = tokenizer(
|
288 |
+
prompt,
|
289 |
+
padding="max_length",
|
290 |
+
max_length=max_sequence_length,
|
291 |
+
truncation=True,
|
292 |
+
add_special_tokens=True,
|
293 |
+
return_tensors="pt",
|
294 |
+
)
|
295 |
+
text_input_ids = text_inputs.input_ids
|
296 |
+
|
297 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
298 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
299 |
+
|
300 |
+
return {"prompt_embeds": prompt_embeds}
|
301 |
+
|
302 |
+
|
303 |
+
def _pad_frames(latents: torch.Tensor, patch_size_t: int):
|
304 |
+
if patch_size_t is None or patch_size_t == 1:
|
305 |
+
return latents
|
306 |
+
|
307 |
+
# `latents` should be of the following format: [B, C, F, H, W].
|
308 |
+
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
309 |
+
latent_num_frames = latents.shape[2]
|
310 |
+
additional_frames = patch_size_t - latent_num_frames % patch_size_t
|
311 |
+
|
312 |
+
if additional_frames > 0:
|
313 |
+
last_frame = latents[:, :, -1:, :, :]
|
314 |
+
padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
|
315 |
+
latents = torch.cat([latents, padding_frames], dim=2)
|
316 |
+
|
317 |
+
return latents
|
318 |
+
|
319 |
+
|
320 |
+
# TODO(aryan): refactor into model specs for better re-use
|
321 |
+
COGVIDEOX_T2V_LORA_CONFIG = {
|
322 |
+
"pipeline_cls": CogVideoXPipeline,
|
323 |
+
"load_condition_models": load_condition_models,
|
324 |
+
"load_latent_models": load_latent_models,
|
325 |
+
"load_diffusion_models": load_diffusion_models,
|
326 |
+
"initialize_pipeline": initialize_pipeline,
|
327 |
+
"prepare_conditions": prepare_conditions,
|
328 |
+
"prepare_latents": prepare_latents,
|
329 |
+
"post_latent_preparation": post_latent_preparation,
|
330 |
+
"collate_fn": collate_fn_t2v,
|
331 |
+
"calculate_noisy_latents": calculate_noisy_latents,
|
332 |
+
"forward_pass": forward_pass,
|
333 |
+
"validation": validation,
|
334 |
+
}
|
finetrainers/models/cogvideox/utils.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
5 |
+
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
6 |
+
|
7 |
+
|
8 |
+
def prepare_rotary_positional_embeddings(
|
9 |
+
height: int,
|
10 |
+
width: int,
|
11 |
+
num_frames: int,
|
12 |
+
vae_scale_factor_spatial: int = 8,
|
13 |
+
patch_size: int = 2,
|
14 |
+
patch_size_t: int = None,
|
15 |
+
attention_head_dim: int = 64,
|
16 |
+
device: Optional[torch.device] = None,
|
17 |
+
base_height: int = 480,
|
18 |
+
base_width: int = 720,
|
19 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
20 |
+
grid_height = height // (vae_scale_factor_spatial * patch_size)
|
21 |
+
grid_width = width // (vae_scale_factor_spatial * patch_size)
|
22 |
+
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
|
23 |
+
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
|
24 |
+
|
25 |
+
if patch_size_t is None:
|
26 |
+
# CogVideoX 1.0
|
27 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
28 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
29 |
+
)
|
30 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
31 |
+
embed_dim=attention_head_dim,
|
32 |
+
crops_coords=grid_crops_coords,
|
33 |
+
grid_size=(grid_height, grid_width),
|
34 |
+
temporal_size=num_frames,
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
# CogVideoX 1.5
|
38 |
+
base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
|
39 |
+
|
40 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
41 |
+
embed_dim=attention_head_dim,
|
42 |
+
crops_coords=None,
|
43 |
+
grid_size=(grid_height, grid_width),
|
44 |
+
temporal_size=base_num_frames,
|
45 |
+
grid_type="slice",
|
46 |
+
max_size=(base_size_height, base_size_width),
|
47 |
+
)
|
48 |
+
|
49 |
+
freqs_cos = freqs_cos.to(device=device)
|
50 |
+
freqs_sin = freqs_sin.to(device=device)
|
51 |
+
return freqs_cos, freqs_sin
|
finetrainers/models/hunyuan_video/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG
|
2 |
+
from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
|
finetrainers/models/hunyuan_video/full_finetune.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import HunyuanVideoPipeline
|
2 |
+
|
3 |
+
from .lora import (
|
4 |
+
collate_fn_t2v,
|
5 |
+
forward_pass,
|
6 |
+
initialize_pipeline,
|
7 |
+
load_condition_models,
|
8 |
+
load_diffusion_models,
|
9 |
+
load_latent_models,
|
10 |
+
post_latent_preparation,
|
11 |
+
prepare_conditions,
|
12 |
+
prepare_latents,
|
13 |
+
validation,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
# TODO(aryan): refactor into model specs for better re-use
|
18 |
+
HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
|
19 |
+
"pipeline_cls": HunyuanVideoPipeline,
|
20 |
+
"load_condition_models": load_condition_models,
|
21 |
+
"load_latent_models": load_latent_models,
|
22 |
+
"load_diffusion_models": load_diffusion_models,
|
23 |
+
"initialize_pipeline": initialize_pipeline,
|
24 |
+
"prepare_conditions": prepare_conditions,
|
25 |
+
"prepare_latents": prepare_latents,
|
26 |
+
"post_latent_preparation": post_latent_preparation,
|
27 |
+
"collate_fn": collate_fn_t2v,
|
28 |
+
"forward_pass": forward_pass,
|
29 |
+
"validation": validation,
|
30 |
+
}
|
finetrainers/models/hunyuan_video/lora.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from accelerate.logging import get_logger
|
6 |
+
from diffusers import (
|
7 |
+
AutoencoderKLHunyuanVideo,
|
8 |
+
FlowMatchEulerDiscreteScheduler,
|
9 |
+
HunyuanVideoPipeline,
|
10 |
+
HunyuanVideoTransformer3DModel,
|
11 |
+
)
|
12 |
+
from PIL import Image
|
13 |
+
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer
|
14 |
+
|
15 |
+
|
16 |
+
logger = get_logger("finetrainers") # pylint: disable=invalid-name
|
17 |
+
|
18 |
+
|
19 |
+
def load_condition_models(
|
20 |
+
model_id: str = "hunyuanvideo-community/HunyuanVideo",
|
21 |
+
text_encoder_dtype: torch.dtype = torch.float16,
|
22 |
+
text_encoder_2_dtype: torch.dtype = torch.float16,
|
23 |
+
revision: Optional[str] = None,
|
24 |
+
cache_dir: Optional[str] = None,
|
25 |
+
**kwargs,
|
26 |
+
) -> Dict[str, nn.Module]:
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
|
28 |
+
text_encoder = LlamaModel.from_pretrained(
|
29 |
+
model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
|
30 |
+
)
|
31 |
+
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
32 |
+
model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir
|
33 |
+
)
|
34 |
+
text_encoder_2 = CLIPTextModel.from_pretrained(
|
35 |
+
model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir
|
36 |
+
)
|
37 |
+
return {
|
38 |
+
"tokenizer": tokenizer,
|
39 |
+
"text_encoder": text_encoder,
|
40 |
+
"tokenizer_2": tokenizer_2,
|
41 |
+
"text_encoder_2": text_encoder_2,
|
42 |
+
}
|
43 |
+
|
44 |
+
|
45 |
+
def load_latent_models(
|
46 |
+
model_id: str = "hunyuanvideo-community/HunyuanVideo",
|
47 |
+
vae_dtype: torch.dtype = torch.float16,
|
48 |
+
revision: Optional[str] = None,
|
49 |
+
cache_dir: Optional[str] = None,
|
50 |
+
**kwargs,
|
51 |
+
) -> Dict[str, nn.Module]:
|
52 |
+
vae = AutoencoderKLHunyuanVideo.from_pretrained(
|
53 |
+
model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
|
54 |
+
)
|
55 |
+
return {"vae": vae}
|
56 |
+
|
57 |
+
|
58 |
+
def load_diffusion_models(
|
59 |
+
model_id: str = "hunyuanvideo-community/HunyuanVideo",
|
60 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
61 |
+
shift: float = 1.0,
|
62 |
+
revision: Optional[str] = None,
|
63 |
+
cache_dir: Optional[str] = None,
|
64 |
+
**kwargs,
|
65 |
+
) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]:
|
66 |
+
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
67 |
+
model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
|
68 |
+
)
|
69 |
+
scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
|
70 |
+
return {"transformer": transformer, "scheduler": scheduler}
|
71 |
+
|
72 |
+
|
73 |
+
def initialize_pipeline(
|
74 |
+
model_id: str = "hunyuanvideo-community/HunyuanVideo",
|
75 |
+
text_encoder_dtype: torch.dtype = torch.float16,
|
76 |
+
text_encoder_2_dtype: torch.dtype = torch.float16,
|
77 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
78 |
+
vae_dtype: torch.dtype = torch.float16,
|
79 |
+
tokenizer: Optional[LlamaTokenizer] = None,
|
80 |
+
text_encoder: Optional[LlamaModel] = None,
|
81 |
+
tokenizer_2: Optional[CLIPTokenizer] = None,
|
82 |
+
text_encoder_2: Optional[CLIPTextModel] = None,
|
83 |
+
transformer: Optional[HunyuanVideoTransformer3DModel] = None,
|
84 |
+
vae: Optional[AutoencoderKLHunyuanVideo] = None,
|
85 |
+
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
|
86 |
+
device: Optional[torch.device] = None,
|
87 |
+
revision: Optional[str] = None,
|
88 |
+
cache_dir: Optional[str] = None,
|
89 |
+
enable_slicing: bool = False,
|
90 |
+
enable_tiling: bool = False,
|
91 |
+
enable_model_cpu_offload: bool = False,
|
92 |
+
is_training: bool = False,
|
93 |
+
**kwargs,
|
94 |
+
) -> HunyuanVideoPipeline:
|
95 |
+
component_name_pairs = [
|
96 |
+
("tokenizer", tokenizer),
|
97 |
+
("text_encoder", text_encoder),
|
98 |
+
("tokenizer_2", tokenizer_2),
|
99 |
+
("text_encoder_2", text_encoder_2),
|
100 |
+
("transformer", transformer),
|
101 |
+
("vae", vae),
|
102 |
+
("scheduler", scheduler),
|
103 |
+
]
|
104 |
+
components = {}
|
105 |
+
for name, component in component_name_pairs:
|
106 |
+
if component is not None:
|
107 |
+
components[name] = component
|
108 |
+
|
109 |
+
pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
|
110 |
+
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
|
111 |
+
pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype)
|
112 |
+
pipe.vae = pipe.vae.to(dtype=vae_dtype)
|
113 |
+
|
114 |
+
# The transformer should already be in the correct dtype when training, so we don't need to cast it here.
|
115 |
+
# If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
|
116 |
+
# DDP optimizer step.
|
117 |
+
if not is_training:
|
118 |
+
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
|
119 |
+
|
120 |
+
if enable_slicing:
|
121 |
+
pipe.vae.enable_slicing()
|
122 |
+
if enable_tiling:
|
123 |
+
pipe.vae.enable_tiling()
|
124 |
+
|
125 |
+
if enable_model_cpu_offload:
|
126 |
+
pipe.enable_model_cpu_offload(device=device)
|
127 |
+
else:
|
128 |
+
pipe.to(device=device)
|
129 |
+
|
130 |
+
return pipe
|
131 |
+
|
132 |
+
|
133 |
+
def prepare_conditions(
|
134 |
+
tokenizer: LlamaTokenizer,
|
135 |
+
text_encoder: LlamaModel,
|
136 |
+
tokenizer_2: CLIPTokenizer,
|
137 |
+
text_encoder_2: CLIPTextModel,
|
138 |
+
prompt: Union[str, List[str]],
|
139 |
+
guidance: float = 1.0,
|
140 |
+
device: Optional[torch.device] = None,
|
141 |
+
dtype: Optional[torch.dtype] = None,
|
142 |
+
max_sequence_length: int = 256,
|
143 |
+
# TODO(aryan): make configurable
|
144 |
+
prompt_template: Dict[str, Any] = {
|
145 |
+
"template": (
|
146 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
147 |
+
"1. The main content and theme of the video."
|
148 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
149 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
150 |
+
"4. background environment, light, style and atmosphere."
|
151 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
152 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
153 |
+
),
|
154 |
+
"crop_start": 95,
|
155 |
+
},
|
156 |
+
**kwargs,
|
157 |
+
) -> torch.Tensor:
|
158 |
+
device = device or text_encoder.device
|
159 |
+
dtype = dtype or text_encoder.dtype
|
160 |
+
|
161 |
+
if isinstance(prompt, str):
|
162 |
+
prompt = [prompt]
|
163 |
+
|
164 |
+
conditions = {}
|
165 |
+
conditions.update(
|
166 |
+
_get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length)
|
167 |
+
)
|
168 |
+
conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype))
|
169 |
+
|
170 |
+
guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0
|
171 |
+
conditions["guidance"] = guidance
|
172 |
+
|
173 |
+
return conditions
|
174 |
+
|
175 |
+
|
176 |
+
def prepare_latents(
|
177 |
+
vae: AutoencoderKLHunyuanVideo,
|
178 |
+
image_or_video: torch.Tensor,
|
179 |
+
device: Optional[torch.device] = None,
|
180 |
+
dtype: Optional[torch.dtype] = None,
|
181 |
+
generator: Optional[torch.Generator] = None,
|
182 |
+
precompute: bool = False,
|
183 |
+
**kwargs,
|
184 |
+
) -> torch.Tensor:
|
185 |
+
device = device or vae.device
|
186 |
+
dtype = dtype or vae.dtype
|
187 |
+
|
188 |
+
if image_or_video.ndim == 4:
|
189 |
+
image_or_video = image_or_video.unsqueeze(2)
|
190 |
+
assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
|
191 |
+
|
192 |
+
image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
|
193 |
+
image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W]
|
194 |
+
if not precompute:
|
195 |
+
latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
|
196 |
+
latents = latents * vae.config.scaling_factor
|
197 |
+
latents = latents.to(dtype=dtype)
|
198 |
+
return {"latents": latents}
|
199 |
+
else:
|
200 |
+
if vae.use_slicing and image_or_video.shape[0] > 1:
|
201 |
+
encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
|
202 |
+
h = torch.cat(encoded_slices)
|
203 |
+
else:
|
204 |
+
h = vae._encode(image_or_video)
|
205 |
+
return {"latents": h}
|
206 |
+
|
207 |
+
|
208 |
+
def post_latent_preparation(
|
209 |
+
vae_config: Dict[str, Any],
|
210 |
+
latents: torch.Tensor,
|
211 |
+
**kwargs,
|
212 |
+
) -> torch.Tensor:
|
213 |
+
latents = latents * vae_config.scaling_factor
|
214 |
+
return {"latents": latents}
|
215 |
+
|
216 |
+
|
217 |
+
def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
218 |
+
return {
|
219 |
+
"prompts": [x["prompt"] for x in batch[0]],
|
220 |
+
"videos": torch.stack([x["video"] for x in batch[0]]),
|
221 |
+
}
|
222 |
+
|
223 |
+
|
224 |
+
def forward_pass(
|
225 |
+
transformer: HunyuanVideoTransformer3DModel,
|
226 |
+
prompt_embeds: torch.Tensor,
|
227 |
+
pooled_prompt_embeds: torch.Tensor,
|
228 |
+
prompt_attention_mask: torch.Tensor,
|
229 |
+
guidance: torch.Tensor,
|
230 |
+
latents: torch.Tensor,
|
231 |
+
noisy_latents: torch.Tensor,
|
232 |
+
timesteps: torch.LongTensor,
|
233 |
+
**kwargs,
|
234 |
+
) -> torch.Tensor:
|
235 |
+
denoised_latents = transformer(
|
236 |
+
hidden_states=noisy_latents,
|
237 |
+
timestep=timesteps,
|
238 |
+
encoder_hidden_states=prompt_embeds,
|
239 |
+
pooled_projections=pooled_prompt_embeds,
|
240 |
+
encoder_attention_mask=prompt_attention_mask,
|
241 |
+
guidance=guidance,
|
242 |
+
return_dict=False,
|
243 |
+
)[0]
|
244 |
+
|
245 |
+
return {"latents": denoised_latents}
|
246 |
+
|
247 |
+
|
248 |
+
def validation(
|
249 |
+
pipeline: HunyuanVideoPipeline,
|
250 |
+
prompt: str,
|
251 |
+
image: Optional[Image.Image] = None,
|
252 |
+
video: Optional[List[Image.Image]] = None,
|
253 |
+
height: Optional[int] = None,
|
254 |
+
width: Optional[int] = None,
|
255 |
+
num_frames: Optional[int] = None,
|
256 |
+
num_videos_per_prompt: int = 1,
|
257 |
+
generator: Optional[torch.Generator] = None,
|
258 |
+
**kwargs,
|
259 |
+
):
|
260 |
+
generation_kwargs = {
|
261 |
+
"prompt": prompt,
|
262 |
+
"height": height,
|
263 |
+
"width": width,
|
264 |
+
"num_frames": num_frames,
|
265 |
+
"num_inference_steps": 30,
|
266 |
+
"num_videos_per_prompt": num_videos_per_prompt,
|
267 |
+
"generator": generator,
|
268 |
+
"return_dict": True,
|
269 |
+
"output_type": "pil",
|
270 |
+
}
|
271 |
+
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
|
272 |
+
output = pipeline(**generation_kwargs).frames[0]
|
273 |
+
return [("video", output)]
|
274 |
+
|
275 |
+
|
276 |
+
def _get_llama_prompt_embeds(
|
277 |
+
tokenizer: LlamaTokenizer,
|
278 |
+
text_encoder: LlamaModel,
|
279 |
+
prompt: List[str],
|
280 |
+
prompt_template: Dict[str, Any],
|
281 |
+
device: torch.device,
|
282 |
+
dtype: torch.dtype,
|
283 |
+
max_sequence_length: int = 256,
|
284 |
+
num_hidden_layers_to_skip: int = 2,
|
285 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
286 |
+
batch_size = len(prompt)
|
287 |
+
prompt = [prompt_template["template"].format(p) for p in prompt]
|
288 |
+
|
289 |
+
crop_start = prompt_template.get("crop_start", None)
|
290 |
+
if crop_start is None:
|
291 |
+
prompt_template_input = tokenizer(
|
292 |
+
prompt_template["template"],
|
293 |
+
padding="max_length",
|
294 |
+
return_tensors="pt",
|
295 |
+
return_length=False,
|
296 |
+
return_overflowing_tokens=False,
|
297 |
+
return_attention_mask=False,
|
298 |
+
)
|
299 |
+
crop_start = prompt_template_input["input_ids"].shape[-1]
|
300 |
+
# Remove <|eot_id|> token and placeholder {}
|
301 |
+
crop_start -= 2
|
302 |
+
|
303 |
+
max_sequence_length += crop_start
|
304 |
+
text_inputs = tokenizer(
|
305 |
+
prompt,
|
306 |
+
max_length=max_sequence_length,
|
307 |
+
padding="max_length",
|
308 |
+
truncation=True,
|
309 |
+
return_tensors="pt",
|
310 |
+
return_length=False,
|
311 |
+
return_overflowing_tokens=False,
|
312 |
+
return_attention_mask=True,
|
313 |
+
)
|
314 |
+
text_input_ids = text_inputs.input_ids.to(device=device)
|
315 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
|
316 |
+
|
317 |
+
prompt_embeds = text_encoder(
|
318 |
+
input_ids=text_input_ids,
|
319 |
+
attention_mask=prompt_attention_mask,
|
320 |
+
output_hidden_states=True,
|
321 |
+
).hidden_states[-(num_hidden_layers_to_skip + 1)]
|
322 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
323 |
+
|
324 |
+
if crop_start is not None and crop_start > 0:
|
325 |
+
prompt_embeds = prompt_embeds[:, crop_start:]
|
326 |
+
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
|
327 |
+
|
328 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
329 |
+
|
330 |
+
return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
|
331 |
+
|
332 |
+
|
333 |
+
def _get_clip_prompt_embeds(
|
334 |
+
tokenizer_2: CLIPTokenizer,
|
335 |
+
text_encoder_2: CLIPTextModel,
|
336 |
+
prompt: Union[str, List[str]],
|
337 |
+
device: torch.device,
|
338 |
+
dtype: torch.dtype,
|
339 |
+
max_sequence_length: int = 77,
|
340 |
+
) -> torch.Tensor:
|
341 |
+
text_inputs = tokenizer_2(
|
342 |
+
prompt,
|
343 |
+
padding="max_length",
|
344 |
+
max_length=max_sequence_length,
|
345 |
+
truncation=True,
|
346 |
+
return_tensors="pt",
|
347 |
+
)
|
348 |
+
|
349 |
+
prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output
|
350 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
351 |
+
|
352 |
+
return {"pooled_prompt_embeds": prompt_embeds}
|
353 |
+
|
354 |
+
|
355 |
+
# TODO(aryan): refactor into model specs for better re-use
|
356 |
+
HUNYUAN_VIDEO_T2V_LORA_CONFIG = {
|
357 |
+
"pipeline_cls": HunyuanVideoPipeline,
|
358 |
+
"load_condition_models": load_condition_models,
|
359 |
+
"load_latent_models": load_latent_models,
|
360 |
+
"load_diffusion_models": load_diffusion_models,
|
361 |
+
"initialize_pipeline": initialize_pipeline,
|
362 |
+
"prepare_conditions": prepare_conditions,
|
363 |
+
"prepare_latents": prepare_latents,
|
364 |
+
"post_latent_preparation": post_latent_preparation,
|
365 |
+
"collate_fn": collate_fn_t2v,
|
366 |
+
"forward_pass": forward_pass,
|
367 |
+
"validation": validation,
|
368 |
+
}
|
finetrainers/models/ltx_video/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .full_finetune import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG
|
2 |
+
from .lora import LTX_VIDEO_T2V_LORA_CONFIG
|
finetrainers/models/ltx_video/full_finetune.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import LTXPipeline
|
2 |
+
|
3 |
+
from .lora import (
|
4 |
+
collate_fn_t2v,
|
5 |
+
forward_pass,
|
6 |
+
initialize_pipeline,
|
7 |
+
load_condition_models,
|
8 |
+
load_diffusion_models,
|
9 |
+
load_latent_models,
|
10 |
+
post_latent_preparation,
|
11 |
+
prepare_conditions,
|
12 |
+
prepare_latents,
|
13 |
+
validation,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
# TODO(aryan): refactor into model specs for better re-use
|
18 |
+
LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
|
19 |
+
"pipeline_cls": LTXPipeline,
|
20 |
+
"load_condition_models": load_condition_models,
|
21 |
+
"load_latent_models": load_latent_models,
|
22 |
+
"load_diffusion_models": load_diffusion_models,
|
23 |
+
"initialize_pipeline": initialize_pipeline,
|
24 |
+
"prepare_conditions": prepare_conditions,
|
25 |
+
"prepare_latents": prepare_latents,
|
26 |
+
"post_latent_preparation": post_latent_preparation,
|
27 |
+
"collate_fn": collate_fn_t2v,
|
28 |
+
"forward_pass": forward_pass,
|
29 |
+
"validation": validation,
|
30 |
+
}
|
finetrainers/models/ltx_video/lora.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from accelerate.logging import get_logger
|
6 |
+
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
9 |
+
|
10 |
+
|
11 |
+
logger = get_logger("finetrainers") # pylint: disable=invalid-name
|
12 |
+
|
13 |
+
|
14 |
+
def load_condition_models(
|
15 |
+
model_id: str = "Lightricks/LTX-Video",
|
16 |
+
text_encoder_dtype: torch.dtype = torch.bfloat16,
|
17 |
+
revision: Optional[str] = None,
|
18 |
+
cache_dir: Optional[str] = None,
|
19 |
+
**kwargs,
|
20 |
+
) -> Dict[str, nn.Module]:
|
21 |
+
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
|
22 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
23 |
+
model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
|
24 |
+
)
|
25 |
+
return {"tokenizer": tokenizer, "text_encoder": text_encoder}
|
26 |
+
|
27 |
+
|
28 |
+
def load_latent_models(
|
29 |
+
model_id: str = "Lightricks/LTX-Video",
|
30 |
+
vae_dtype: torch.dtype = torch.bfloat16,
|
31 |
+
revision: Optional[str] = None,
|
32 |
+
cache_dir: Optional[str] = None,
|
33 |
+
**kwargs,
|
34 |
+
) -> Dict[str, nn.Module]:
|
35 |
+
vae = AutoencoderKLLTXVideo.from_pretrained(
|
36 |
+
model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
|
37 |
+
)
|
38 |
+
return {"vae": vae}
|
39 |
+
|
40 |
+
|
41 |
+
def load_diffusion_models(
|
42 |
+
model_id: str = "Lightricks/LTX-Video",
|
43 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
44 |
+
revision: Optional[str] = None,
|
45 |
+
cache_dir: Optional[str] = None,
|
46 |
+
**kwargs,
|
47 |
+
) -> Dict[str, nn.Module]:
|
48 |
+
transformer = LTXVideoTransformer3DModel.from_pretrained(
|
49 |
+
model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
|
50 |
+
)
|
51 |
+
scheduler = FlowMatchEulerDiscreteScheduler()
|
52 |
+
return {"transformer": transformer, "scheduler": scheduler}
|
53 |
+
|
54 |
+
|
55 |
+
def initialize_pipeline(
|
56 |
+
model_id: str = "Lightricks/LTX-Video",
|
57 |
+
text_encoder_dtype: torch.dtype = torch.bfloat16,
|
58 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
59 |
+
vae_dtype: torch.dtype = torch.bfloat16,
|
60 |
+
tokenizer: Optional[T5Tokenizer] = None,
|
61 |
+
text_encoder: Optional[T5EncoderModel] = None,
|
62 |
+
transformer: Optional[LTXVideoTransformer3DModel] = None,
|
63 |
+
vae: Optional[AutoencoderKLLTXVideo] = None,
|
64 |
+
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
|
65 |
+
device: Optional[torch.device] = None,
|
66 |
+
revision: Optional[str] = None,
|
67 |
+
cache_dir: Optional[str] = None,
|
68 |
+
enable_slicing: bool = False,
|
69 |
+
enable_tiling: bool = False,
|
70 |
+
enable_model_cpu_offload: bool = False,
|
71 |
+
is_training: bool = False,
|
72 |
+
**kwargs,
|
73 |
+
) -> LTXPipeline:
|
74 |
+
component_name_pairs = [
|
75 |
+
("tokenizer", tokenizer),
|
76 |
+
("text_encoder", text_encoder),
|
77 |
+
("transformer", transformer),
|
78 |
+
("vae", vae),
|
79 |
+
("scheduler", scheduler),
|
80 |
+
]
|
81 |
+
components = {}
|
82 |
+
for name, component in component_name_pairs:
|
83 |
+
if component is not None:
|
84 |
+
components[name] = component
|
85 |
+
|
86 |
+
pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
|
87 |
+
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
|
88 |
+
pipe.vae = pipe.vae.to(dtype=vae_dtype)
|
89 |
+
# The transformer should already be in the correct dtype when training, so we don't need to cast it here.
|
90 |
+
# If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
|
91 |
+
# DDP optimizer step.
|
92 |
+
if not is_training:
|
93 |
+
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
|
94 |
+
|
95 |
+
if enable_slicing:
|
96 |
+
pipe.vae.enable_slicing()
|
97 |
+
if enable_tiling:
|
98 |
+
pipe.vae.enable_tiling()
|
99 |
+
|
100 |
+
if enable_model_cpu_offload:
|
101 |
+
pipe.enable_model_cpu_offload(device=device)
|
102 |
+
else:
|
103 |
+
pipe.to(device=device)
|
104 |
+
|
105 |
+
return pipe
|
106 |
+
|
107 |
+
|
108 |
+
def prepare_conditions(
|
109 |
+
tokenizer: T5Tokenizer,
|
110 |
+
text_encoder: T5EncoderModel,
|
111 |
+
prompt: Union[str, List[str]],
|
112 |
+
device: Optional[torch.device] = None,
|
113 |
+
dtype: Optional[torch.dtype] = None,
|
114 |
+
max_sequence_length: int = 128,
|
115 |
+
**kwargs,
|
116 |
+
) -> torch.Tensor:
|
117 |
+
device = device or text_encoder.device
|
118 |
+
dtype = dtype or text_encoder.dtype
|
119 |
+
|
120 |
+
if isinstance(prompt, str):
|
121 |
+
prompt = [prompt]
|
122 |
+
|
123 |
+
return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length)
|
124 |
+
|
125 |
+
|
126 |
+
def prepare_latents(
|
127 |
+
vae: AutoencoderKLLTXVideo,
|
128 |
+
image_or_video: torch.Tensor,
|
129 |
+
patch_size: int = 1,
|
130 |
+
patch_size_t: int = 1,
|
131 |
+
device: Optional[torch.device] = None,
|
132 |
+
dtype: Optional[torch.dtype] = None,
|
133 |
+
generator: Optional[torch.Generator] = None,
|
134 |
+
precompute: bool = False,
|
135 |
+
) -> torch.Tensor:
|
136 |
+
device = device or vae.device
|
137 |
+
|
138 |
+
if image_or_video.ndim == 4:
|
139 |
+
image_or_video = image_or_video.unsqueeze(2)
|
140 |
+
assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
|
141 |
+
|
142 |
+
image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
|
143 |
+
image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W]
|
144 |
+
if not precompute:
|
145 |
+
latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
|
146 |
+
latents = latents.to(dtype=dtype)
|
147 |
+
_, _, num_frames, height, width = latents.shape
|
148 |
+
latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
|
149 |
+
latents = _pack_latents(latents, patch_size, patch_size_t)
|
150 |
+
return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
|
151 |
+
else:
|
152 |
+
if vae.use_slicing and image_or_video.shape[0] > 1:
|
153 |
+
encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
|
154 |
+
h = torch.cat(encoded_slices)
|
155 |
+
else:
|
156 |
+
h = vae._encode(image_or_video)
|
157 |
+
_, _, num_frames, height, width = h.shape
|
158 |
+
|
159 |
+
# TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file
|
160 |
+
# if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored
|
161 |
+
# so as to reduce the disk memory requirements of the precomputed files.
|
162 |
+
return {
|
163 |
+
"latents": h,
|
164 |
+
"num_frames": num_frames,
|
165 |
+
"height": height,
|
166 |
+
"width": width,
|
167 |
+
"latents_mean": vae.latents_mean,
|
168 |
+
"latents_std": vae.latents_std,
|
169 |
+
}
|
170 |
+
|
171 |
+
|
172 |
+
def post_latent_preparation(
|
173 |
+
latents: torch.Tensor,
|
174 |
+
latents_mean: torch.Tensor,
|
175 |
+
latents_std: torch.Tensor,
|
176 |
+
num_frames: int,
|
177 |
+
height: int,
|
178 |
+
width: int,
|
179 |
+
patch_size: int = 1,
|
180 |
+
patch_size_t: int = 1,
|
181 |
+
**kwargs,
|
182 |
+
) -> torch.Tensor:
|
183 |
+
latents = _normalize_latents(latents, latents_mean, latents_std)
|
184 |
+
latents = _pack_latents(latents, patch_size, patch_size_t)
|
185 |
+
return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
|
186 |
+
|
187 |
+
|
188 |
+
def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
189 |
+
return {
|
190 |
+
"prompts": [x["prompt"] for x in batch[0]],
|
191 |
+
"videos": torch.stack([x["video"] for x in batch[0]]),
|
192 |
+
}
|
193 |
+
|
194 |
+
|
195 |
+
def forward_pass(
|
196 |
+
transformer: LTXVideoTransformer3DModel,
|
197 |
+
prompt_embeds: torch.Tensor,
|
198 |
+
prompt_attention_mask: torch.Tensor,
|
199 |
+
latents: torch.Tensor,
|
200 |
+
noisy_latents: torch.Tensor,
|
201 |
+
timesteps: torch.LongTensor,
|
202 |
+
num_frames: int,
|
203 |
+
height: int,
|
204 |
+
width: int,
|
205 |
+
**kwargs,
|
206 |
+
) -> torch.Tensor:
|
207 |
+
# TODO(aryan): make configurable
|
208 |
+
frame_rate = 25
|
209 |
+
latent_frame_rate = frame_rate / 8
|
210 |
+
spatial_compression_ratio = 32
|
211 |
+
rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio]
|
212 |
+
|
213 |
+
denoised_latents = transformer(
|
214 |
+
hidden_states=noisy_latents,
|
215 |
+
encoder_hidden_states=prompt_embeds,
|
216 |
+
timestep=timesteps,
|
217 |
+
encoder_attention_mask=prompt_attention_mask,
|
218 |
+
num_frames=num_frames,
|
219 |
+
height=height,
|
220 |
+
width=width,
|
221 |
+
rope_interpolation_scale=rope_interpolation_scale,
|
222 |
+
return_dict=False,
|
223 |
+
)[0]
|
224 |
+
|
225 |
+
return {"latents": denoised_latents}
|
226 |
+
|
227 |
+
|
228 |
+
def validation(
|
229 |
+
pipeline: LTXPipeline,
|
230 |
+
prompt: str,
|
231 |
+
image: Optional[Image.Image] = None,
|
232 |
+
video: Optional[List[Image.Image]] = None,
|
233 |
+
height: Optional[int] = None,
|
234 |
+
width: Optional[int] = None,
|
235 |
+
num_frames: Optional[int] = None,
|
236 |
+
frame_rate: int = 24,
|
237 |
+
num_videos_per_prompt: int = 1,
|
238 |
+
generator: Optional[torch.Generator] = None,
|
239 |
+
**kwargs,
|
240 |
+
):
|
241 |
+
generation_kwargs = {
|
242 |
+
"prompt": prompt,
|
243 |
+
"height": height,
|
244 |
+
"width": width,
|
245 |
+
"num_frames": num_frames,
|
246 |
+
"frame_rate": frame_rate,
|
247 |
+
"num_videos_per_prompt": num_videos_per_prompt,
|
248 |
+
"generator": generator,
|
249 |
+
"return_dict": True,
|
250 |
+
"output_type": "pil",
|
251 |
+
}
|
252 |
+
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
|
253 |
+
video = pipeline(**generation_kwargs).frames[0]
|
254 |
+
return [("video", video)]
|
255 |
+
|
256 |
+
|
257 |
+
def _encode_prompt_t5(
|
258 |
+
tokenizer: T5Tokenizer,
|
259 |
+
text_encoder: T5EncoderModel,
|
260 |
+
prompt: List[str],
|
261 |
+
device: torch.device,
|
262 |
+
dtype: torch.dtype,
|
263 |
+
max_sequence_length,
|
264 |
+
) -> torch.Tensor:
|
265 |
+
batch_size = len(prompt)
|
266 |
+
|
267 |
+
text_inputs = tokenizer(
|
268 |
+
prompt,
|
269 |
+
padding="max_length",
|
270 |
+
max_length=max_sequence_length,
|
271 |
+
truncation=True,
|
272 |
+
add_special_tokens=True,
|
273 |
+
return_tensors="pt",
|
274 |
+
)
|
275 |
+
text_input_ids = text_inputs.input_ids
|
276 |
+
prompt_attention_mask = text_inputs.attention_mask
|
277 |
+
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
278 |
+
|
279 |
+
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
280 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
281 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
282 |
+
|
283 |
+
return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
|
284 |
+
|
285 |
+
|
286 |
+
def _normalize_latents(
|
287 |
+
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
288 |
+
) -> torch.Tensor:
|
289 |
+
# Normalize latents across the channel dimension [B, C, F, H, W]
|
290 |
+
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
291 |
+
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
292 |
+
latents = (latents - latents_mean) * scaling_factor / latents_std
|
293 |
+
return latents
|
294 |
+
|
295 |
+
|
296 |
+
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
297 |
+
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
298 |
+
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
299 |
+
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
300 |
+
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
301 |
+
batch_size, num_channels, num_frames, height, width = latents.shape
|
302 |
+
post_patch_num_frames = num_frames // patch_size_t
|
303 |
+
post_patch_height = height // patch_size
|
304 |
+
post_patch_width = width // patch_size
|
305 |
+
latents = latents.reshape(
|
306 |
+
batch_size,
|
307 |
+
-1,
|
308 |
+
post_patch_num_frames,
|
309 |
+
patch_size_t,
|
310 |
+
post_patch_height,
|
311 |
+
patch_size,
|
312 |
+
post_patch_width,
|
313 |
+
patch_size,
|
314 |
+
)
|
315 |
+
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
316 |
+
return latents
|
317 |
+
|
318 |
+
|
319 |
+
LTX_VIDEO_T2V_LORA_CONFIG = {
|
320 |
+
"pipeline_cls": LTXPipeline,
|
321 |
+
"load_condition_models": load_condition_models,
|
322 |
+
"load_latent_models": load_latent_models,
|
323 |
+
"load_diffusion_models": load_diffusion_models,
|
324 |
+
"initialize_pipeline": initialize_pipeline,
|
325 |
+
"prepare_conditions": prepare_conditions,
|
326 |
+
"prepare_latents": prepare_latents,
|
327 |
+
"post_latent_preparation": post_latent_preparation,
|
328 |
+
"collate_fn": collate_fn_t2v,
|
329 |
+
"forward_pass": forward_pass,
|
330 |
+
"validation": validation,
|
331 |
+
}
|
finetrainers/patches.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from accelerate.logging import get_logger
|
5 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
6 |
+
|
7 |
+
from .constants import FINETRAINERS_LOG_LEVEL
|
8 |
+
|
9 |
+
|
10 |
+
logger = get_logger("finetrainers") # pylint: disable=invalid-name
|
11 |
+
logger.setLevel(FINETRAINERS_LOG_LEVEL)
|
12 |
+
|
13 |
+
|
14 |
+
def perform_peft_patches() -> None:
|
15 |
+
_perform_patch_move_adapter_to_device_of_base_layer()
|
16 |
+
|
17 |
+
|
18 |
+
def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
|
19 |
+
# We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights
|
20 |
+
# are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of
|
21 |
+
# LoRA weights from higher precision dtype.
|
22 |
+
BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
|
23 |
+
BaseTunerLayer._move_adapter_to_device_of_base_layer
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def _patched_move_adapter_to_device_of_base_layer(func) -> None:
|
28 |
+
@functools.wraps(func)
|
29 |
+
def wrapper(self, *args, **kwargs):
|
30 |
+
with DisableTensorToDtype():
|
31 |
+
return func(self, *args, **kwargs)
|
32 |
+
|
33 |
+
return wrapper
|
34 |
+
|
35 |
+
|
36 |
+
class DisableTensorToDtype:
|
37 |
+
def __enter__(self):
|
38 |
+
self.original_to = torch.Tensor.to
|
39 |
+
|
40 |
+
def modified_to(tensor, *args, **kwargs):
|
41 |
+
# remove dtype from args if present
|
42 |
+
args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
|
43 |
+
if "dtype" in kwargs:
|
44 |
+
kwargs.pop("dtype")
|
45 |
+
return self.original_to(tensor, *args, **kwargs)
|
46 |
+
|
47 |
+
torch.Tensor.to = modified_to
|
48 |
+
|
49 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
50 |
+
torch.Tensor.to = self.original_to
|
finetrainers/state.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from accelerate import Accelerator
|
3 |
+
|
4 |
+
|
5 |
+
class State:
|
6 |
+
# Training state
|
7 |
+
seed: int = None
|
8 |
+
model_name: str = None
|
9 |
+
accelerator: Accelerator = None
|
10 |
+
weight_dtype: torch.dtype = None
|
11 |
+
train_epochs: int = None
|
12 |
+
train_steps: int = None
|
13 |
+
overwrote_max_train_steps: bool = False
|
14 |
+
num_trainable_parameters: int = 0
|
15 |
+
learning_rate: float = None
|
16 |
+
train_batch_size: int = None
|
17 |
+
generator: torch.Generator = None
|
18 |
+
num_update_steps_per_epoch: int = None
|
19 |
+
|
20 |
+
# Hub state
|
21 |
+
repo_id: str = None
|
22 |
+
|
23 |
+
# Artifacts state
|
24 |
+
output_dir: str = None
|
finetrainers/trainer.py
ADDED
@@ -0,0 +1,1207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
from datetime import datetime, timedelta
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any, Dict, List
|
9 |
+
|
10 |
+
import diffusers
|
11 |
+
import torch
|
12 |
+
import torch.backends
|
13 |
+
import transformers
|
14 |
+
import wandb
|
15 |
+
from accelerate import Accelerator, DistributedType
|
16 |
+
from accelerate.logging import get_logger
|
17 |
+
from accelerate.utils import (
|
18 |
+
DistributedDataParallelKwargs,
|
19 |
+
InitProcessGroupKwargs,
|
20 |
+
ProjectConfiguration,
|
21 |
+
gather_object,
|
22 |
+
set_seed,
|
23 |
+
)
|
24 |
+
from diffusers import DiffusionPipeline
|
25 |
+
from diffusers.configuration_utils import FrozenDict
|
26 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
27 |
+
from diffusers.optimization import get_scheduler
|
28 |
+
from diffusers.training_utils import cast_training_params
|
29 |
+
from diffusers.utils import export_to_video, load_image, load_video
|
30 |
+
from huggingface_hub import create_repo, upload_folder
|
31 |
+
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
32 |
+
from tqdm import tqdm
|
33 |
+
|
34 |
+
from .args import Args, validate_args
|
35 |
+
from .constants import (
|
36 |
+
FINETRAINERS_LOG_LEVEL,
|
37 |
+
PRECOMPUTED_CONDITIONS_DIR_NAME,
|
38 |
+
PRECOMPUTED_DIR_NAME,
|
39 |
+
PRECOMPUTED_LATENTS_DIR_NAME,
|
40 |
+
)
|
41 |
+
from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
|
42 |
+
from .hooks import apply_layerwise_upcasting
|
43 |
+
from .models import get_config_from_model_name
|
44 |
+
from .patches import perform_peft_patches
|
45 |
+
from .state import State
|
46 |
+
from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
|
47 |
+
from .utils.data_utils import should_perform_precomputation
|
48 |
+
from .utils.diffusion_utils import (
|
49 |
+
get_scheduler_alphas,
|
50 |
+
get_scheduler_sigmas,
|
51 |
+
prepare_loss_weights,
|
52 |
+
prepare_sigmas,
|
53 |
+
prepare_target,
|
54 |
+
)
|
55 |
+
from .utils.file_utils import string_to_filename
|
56 |
+
from .utils.hub_utils import save_model_card
|
57 |
+
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
|
58 |
+
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
|
59 |
+
from .utils.optimizer_utils import get_optimizer
|
60 |
+
from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
|
61 |
+
|
62 |
+
|
63 |
+
logger = get_logger("finetrainers")
|
64 |
+
logger.setLevel(FINETRAINERS_LOG_LEVEL)
|
65 |
+
|
66 |
+
|
67 |
+
class Trainer:
|
68 |
+
def __init__(self, args: Args) -> None:
|
69 |
+
validate_args(args)
|
70 |
+
|
71 |
+
self.args = args
|
72 |
+
self.args.seed = self.args.seed or datetime.now().year
|
73 |
+
self.state = State()
|
74 |
+
|
75 |
+
# Tokenizers
|
76 |
+
self.tokenizer = None
|
77 |
+
self.tokenizer_2 = None
|
78 |
+
self.tokenizer_3 = None
|
79 |
+
|
80 |
+
# Text encoders
|
81 |
+
self.text_encoder = None
|
82 |
+
self.text_encoder_2 = None
|
83 |
+
self.text_encoder_3 = None
|
84 |
+
|
85 |
+
# Denoisers
|
86 |
+
self.transformer = None
|
87 |
+
self.unet = None
|
88 |
+
|
89 |
+
# Autoencoders
|
90 |
+
self.vae = None
|
91 |
+
|
92 |
+
# Scheduler
|
93 |
+
self.scheduler = None
|
94 |
+
|
95 |
+
self.transformer_config = None
|
96 |
+
self.vae_config = None
|
97 |
+
|
98 |
+
self._init_distributed()
|
99 |
+
self._init_logging()
|
100 |
+
self._init_directories_and_repositories()
|
101 |
+
self._init_config_options()
|
102 |
+
|
103 |
+
# Peform any patches needed for training
|
104 |
+
if len(self.args.layerwise_upcasting_modules) > 0:
|
105 |
+
perform_peft_patches()
|
106 |
+
# TODO(aryan): handle text encoders
|
107 |
+
# if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
|
108 |
+
# perform_text_encoder_patches()
|
109 |
+
|
110 |
+
self.state.model_name = self.args.model_name
|
111 |
+
self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
|
112 |
+
|
113 |
+
def prepare_dataset(self) -> None:
|
114 |
+
# TODO(aryan): Make a background process for fetching
|
115 |
+
logger.info("Initializing dataset and dataloader")
|
116 |
+
|
117 |
+
self.dataset = ImageOrVideoDatasetWithResizing(
|
118 |
+
data_root=self.args.data_root,
|
119 |
+
caption_column=self.args.caption_column,
|
120 |
+
video_column=self.args.video_column,
|
121 |
+
resolution_buckets=self.args.video_resolution_buckets,
|
122 |
+
dataset_file=self.args.dataset_file,
|
123 |
+
id_token=self.args.id_token,
|
124 |
+
remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes,
|
125 |
+
)
|
126 |
+
self.dataloader = torch.utils.data.DataLoader(
|
127 |
+
self.dataset,
|
128 |
+
batch_size=1,
|
129 |
+
sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
|
130 |
+
collate_fn=self.model_config.get("collate_fn"),
|
131 |
+
num_workers=self.args.dataloader_num_workers,
|
132 |
+
pin_memory=self.args.pin_memory,
|
133 |
+
)
|
134 |
+
|
135 |
+
def prepare_models(self) -> None:
|
136 |
+
logger.info("Initializing models")
|
137 |
+
|
138 |
+
load_components_kwargs = self._get_load_components_kwargs()
|
139 |
+
condition_components, latent_components, diffusion_components = {}, {}, {}
|
140 |
+
if not self.args.precompute_conditions:
|
141 |
+
# To download the model files first on the main process (if not already present)
|
142 |
+
# and then load the cached files afterward from the other processes.
|
143 |
+
with self.state.accelerator.main_process_first():
|
144 |
+
condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
|
145 |
+
latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
|
146 |
+
diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)
|
147 |
+
|
148 |
+
components = {}
|
149 |
+
components.update(condition_components)
|
150 |
+
components.update(latent_components)
|
151 |
+
components.update(diffusion_components)
|
152 |
+
self._set_components(components)
|
153 |
+
|
154 |
+
if self.vae is not None:
|
155 |
+
if self.args.enable_slicing:
|
156 |
+
self.vae.enable_slicing()
|
157 |
+
if self.args.enable_tiling:
|
158 |
+
self.vae.enable_tiling()
|
159 |
+
|
160 |
+
def prepare_precomputations(self) -> None:
|
161 |
+
if not self.args.precompute_conditions:
|
162 |
+
return
|
163 |
+
|
164 |
+
logger.info("Initializing precomputations")
|
165 |
+
|
166 |
+
if self.args.batch_size != 1:
|
167 |
+
raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.")
|
168 |
+
|
169 |
+
def collate_fn(batch):
|
170 |
+
latent_conditions = [x["latent_conditions"] for x in batch]
|
171 |
+
text_conditions = [x["text_conditions"] for x in batch]
|
172 |
+
batched_latent_conditions = {}
|
173 |
+
batched_text_conditions = {}
|
174 |
+
for key in list(latent_conditions[0].keys()):
|
175 |
+
if torch.is_tensor(latent_conditions[0][key]):
|
176 |
+
batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
|
177 |
+
else:
|
178 |
+
# TODO(aryan): implement batch sampler for precomputed latents
|
179 |
+
batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
|
180 |
+
for key in list(text_conditions[0].keys()):
|
181 |
+
if torch.is_tensor(text_conditions[0][key]):
|
182 |
+
batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
|
183 |
+
else:
|
184 |
+
# TODO(aryan): implement batch sampler for precomputed latents
|
185 |
+
batched_text_conditions[key] = [x[key] for x in text_conditions][0]
|
186 |
+
return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}
|
187 |
+
|
188 |
+
cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
|
189 |
+
precomputation_dir = (
|
190 |
+
Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
|
191 |
+
)
|
192 |
+
should_precompute = should_perform_precomputation(precomputation_dir)
|
193 |
+
if not should_precompute:
|
194 |
+
logger.info("Precomputed conditions and latents found. Loading precomputed data.")
|
195 |
+
self.dataloader = torch.utils.data.DataLoader(
|
196 |
+
PrecomputedDataset(
|
197 |
+
data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
|
198 |
+
),
|
199 |
+
batch_size=self.args.batch_size,
|
200 |
+
shuffle=True,
|
201 |
+
collate_fn=collate_fn,
|
202 |
+
num_workers=self.args.dataloader_num_workers,
|
203 |
+
pin_memory=self.args.pin_memory,
|
204 |
+
)
|
205 |
+
return
|
206 |
+
|
207 |
+
logger.info("Precomputed conditions and latents not found. Running precomputation.")
|
208 |
+
|
209 |
+
# At this point, no models are loaded, so we need to load and precompute conditions and latents
|
210 |
+
with self.state.accelerator.main_process_first():
|
211 |
+
condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
|
212 |
+
self._set_components(condition_components)
|
213 |
+
self._move_components_to_device()
|
214 |
+
self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
|
215 |
+
|
216 |
+
if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
|
217 |
+
logger.warning(
|
218 |
+
"Caption dropout is not supported with precomputation yet. This will be supported in the future."
|
219 |
+
)
|
220 |
+
|
221 |
+
conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
|
222 |
+
latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
|
223 |
+
conditions_dir.mkdir(parents=True, exist_ok=True)
|
224 |
+
latents_dir.mkdir(parents=True, exist_ok=True)
|
225 |
+
|
226 |
+
accelerator = self.state.accelerator
|
227 |
+
|
228 |
+
# Precompute conditions
|
229 |
+
progress_bar = tqdm(
|
230 |
+
range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
|
231 |
+
desc="Precomputing conditions",
|
232 |
+
disable=not accelerator.is_local_main_process,
|
233 |
+
)
|
234 |
+
index = 0
|
235 |
+
for i, data in enumerate(self.dataset):
|
236 |
+
if i % accelerator.num_processes != accelerator.process_index:
|
237 |
+
continue
|
238 |
+
|
239 |
+
logger.debug(
|
240 |
+
f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
|
241 |
+
)
|
242 |
+
|
243 |
+
text_conditions = self.model_config["prepare_conditions"](
|
244 |
+
tokenizer=self.tokenizer,
|
245 |
+
tokenizer_2=self.tokenizer_2,
|
246 |
+
tokenizer_3=self.tokenizer_3,
|
247 |
+
text_encoder=self.text_encoder,
|
248 |
+
text_encoder_2=self.text_encoder_2,
|
249 |
+
text_encoder_3=self.text_encoder_3,
|
250 |
+
prompt=data["prompt"],
|
251 |
+
device=accelerator.device,
|
252 |
+
dtype=self.args.transformer_dtype,
|
253 |
+
)
|
254 |
+
filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
|
255 |
+
torch.save(text_conditions, filename.as_posix())
|
256 |
+
index += 1
|
257 |
+
progress_bar.update(1)
|
258 |
+
self._delete_components()
|
259 |
+
|
260 |
+
memory_statistics = get_memory_statistics()
|
261 |
+
logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
|
262 |
+
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
263 |
+
|
264 |
+
# Precompute latents
|
265 |
+
with self.state.accelerator.main_process_first():
|
266 |
+
latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
|
267 |
+
self._set_components(latent_components)
|
268 |
+
self._move_components_to_device()
|
269 |
+
self._disable_grad_for_components([self.vae])
|
270 |
+
|
271 |
+
if self.vae is not None:
|
272 |
+
if self.args.enable_slicing:
|
273 |
+
self.vae.enable_slicing()
|
274 |
+
if self.args.enable_tiling:
|
275 |
+
self.vae.enable_tiling()
|
276 |
+
|
277 |
+
progress_bar = tqdm(
|
278 |
+
range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
|
279 |
+
desc="Precomputing latents",
|
280 |
+
disable=not accelerator.is_local_main_process,
|
281 |
+
)
|
282 |
+
index = 0
|
283 |
+
for i, data in enumerate(self.dataset):
|
284 |
+
if i % accelerator.num_processes != accelerator.process_index:
|
285 |
+
continue
|
286 |
+
|
287 |
+
logger.debug(
|
288 |
+
f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
|
289 |
+
)
|
290 |
+
|
291 |
+
latent_conditions = self.model_config["prepare_latents"](
|
292 |
+
vae=self.vae,
|
293 |
+
image_or_video=data["video"].unsqueeze(0),
|
294 |
+
device=accelerator.device,
|
295 |
+
dtype=self.args.transformer_dtype,
|
296 |
+
generator=self.state.generator,
|
297 |
+
precompute=True,
|
298 |
+
)
|
299 |
+
filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
|
300 |
+
torch.save(latent_conditions, filename.as_posix())
|
301 |
+
index += 1
|
302 |
+
progress_bar.update(1)
|
303 |
+
self._delete_components()
|
304 |
+
|
305 |
+
accelerator.wait_for_everyone()
|
306 |
+
logger.info("Precomputation complete")
|
307 |
+
|
308 |
+
memory_statistics = get_memory_statistics()
|
309 |
+
logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
|
310 |
+
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
311 |
+
|
312 |
+
# Update dataloader to use precomputed conditions and latents
|
313 |
+
self.dataloader = torch.utils.data.DataLoader(
|
314 |
+
PrecomputedDataset(
|
315 |
+
data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
|
316 |
+
),
|
317 |
+
batch_size=self.args.batch_size,
|
318 |
+
shuffle=True,
|
319 |
+
collate_fn=collate_fn,
|
320 |
+
num_workers=self.args.dataloader_num_workers,
|
321 |
+
pin_memory=self.args.pin_memory,
|
322 |
+
)
|
323 |
+
|
324 |
+
def prepare_trainable_parameters(self) -> None:
|
325 |
+
logger.info("Initializing trainable parameters")
|
326 |
+
|
327 |
+
with self.state.accelerator.main_process_first():
|
328 |
+
diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
|
329 |
+
self._set_components(diffusion_components)
|
330 |
+
|
331 |
+
components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
|
332 |
+
self._disable_grad_for_components(components)
|
333 |
+
|
334 |
+
if self.args.training_type == "full-finetune":
|
335 |
+
logger.info("Finetuning transformer with no additional parameters")
|
336 |
+
self._enable_grad_for_components([self.transformer])
|
337 |
+
else:
|
338 |
+
logger.info("Finetuning transformer with PEFT parameters")
|
339 |
+
self._disable_grad_for_components([self.transformer])
|
340 |
+
|
341 |
+
# Layerwise upcasting must be applied before adding the LoRA adapter.
|
342 |
+
# If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
|
343 |
+
# CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
|
344 |
+
if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
|
345 |
+
apply_layerwise_upcasting(
|
346 |
+
self.transformer,
|
347 |
+
storage_dtype=self.args.layerwise_upcasting_storage_dtype,
|
348 |
+
compute_dtype=self.args.transformer_dtype,
|
349 |
+
skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
|
350 |
+
non_blocking=True,
|
351 |
+
)
|
352 |
+
|
353 |
+
self._move_components_to_device()
|
354 |
+
|
355 |
+
if self.args.gradient_checkpointing:
|
356 |
+
self.transformer.enable_gradient_checkpointing()
|
357 |
+
|
358 |
+
if self.args.training_type == "lora":
|
359 |
+
transformer_lora_config = LoraConfig(
|
360 |
+
r=self.args.rank,
|
361 |
+
lora_alpha=self.args.lora_alpha,
|
362 |
+
init_lora_weights=True,
|
363 |
+
target_modules=self.args.target_modules,
|
364 |
+
)
|
365 |
+
self.transformer.add_adapter(transformer_lora_config)
|
366 |
+
else:
|
367 |
+
transformer_lora_config = None
|
368 |
+
|
369 |
+
# TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
|
370 |
+
# even if layerwise upcasting. Would be nice to have a test as well
|
371 |
+
|
372 |
+
self.register_saving_loading_hooks(transformer_lora_config)
|
373 |
+
|
374 |
+
def register_saving_loading_hooks(self, transformer_lora_config):
|
375 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
376 |
+
def save_model_hook(models, weights, output_dir):
|
377 |
+
if self.state.accelerator.is_main_process:
|
378 |
+
transformer_lora_layers_to_save = None
|
379 |
+
|
380 |
+
for model in models:
|
381 |
+
if isinstance(
|
382 |
+
unwrap_model(self.state.accelerator, model),
|
383 |
+
type(unwrap_model(self.state.accelerator, self.transformer)),
|
384 |
+
):
|
385 |
+
model = unwrap_model(self.state.accelerator, model)
|
386 |
+
if self.args.training_type == "lora":
|
387 |
+
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
388 |
+
else:
|
389 |
+
raise ValueError(f"Unexpected save model: {model.__class__}")
|
390 |
+
|
391 |
+
# make sure to pop weight so that corresponding model is not saved again
|
392 |
+
if weights:
|
393 |
+
weights.pop()
|
394 |
+
|
395 |
+
if self.args.training_type == "lora":
|
396 |
+
self.model_config["pipeline_cls"].save_lora_weights(
|
397 |
+
output_dir,
|
398 |
+
transformer_lora_layers=transformer_lora_layers_to_save,
|
399 |
+
)
|
400 |
+
else:
|
401 |
+
model.save_pretrained(os.path.join(output_dir, "transformer"))
|
402 |
+
|
403 |
+
# In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
|
404 |
+
# to able to load all diffusion components from a specific checkpoint folder during validation, we need to
|
405 |
+
# ensure the scheduler config is serialized as well.
|
406 |
+
self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
|
407 |
+
|
408 |
+
def load_model_hook(models, input_dir):
|
409 |
+
if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
|
410 |
+
while len(models) > 0:
|
411 |
+
model = models.pop()
|
412 |
+
if isinstance(
|
413 |
+
unwrap_model(self.state.accelerator, model),
|
414 |
+
type(unwrap_model(self.state.accelerator, self.transformer)),
|
415 |
+
):
|
416 |
+
transformer_ = unwrap_model(self.state.accelerator, model)
|
417 |
+
else:
|
418 |
+
raise ValueError(
|
419 |
+
f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
|
423 |
+
|
424 |
+
if self.args.training_type == "lora":
|
425 |
+
transformer_ = transformer_cls_.from_pretrained(
|
426 |
+
self.args.pretrained_model_name_or_path, subfolder="transformer"
|
427 |
+
)
|
428 |
+
transformer_.add_adapter(transformer_lora_config)
|
429 |
+
lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
|
430 |
+
transformer_state_dict = {
|
431 |
+
f'{k.replace("transformer.", "")}': v
|
432 |
+
for k, v in lora_state_dict.items()
|
433 |
+
if k.startswith("transformer.")
|
434 |
+
}
|
435 |
+
incompatible_keys = set_peft_model_state_dict(
|
436 |
+
transformer_, transformer_state_dict, adapter_name="default"
|
437 |
+
)
|
438 |
+
if incompatible_keys is not None:
|
439 |
+
# check only for unexpected keys
|
440 |
+
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
441 |
+
if unexpected_keys:
|
442 |
+
logger.warning(
|
443 |
+
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
444 |
+
f" {unexpected_keys}. "
|
445 |
+
)
|
446 |
+
else:
|
447 |
+
transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
|
448 |
+
|
449 |
+
self.state.accelerator.register_save_state_pre_hook(save_model_hook)
|
450 |
+
self.state.accelerator.register_load_state_pre_hook(load_model_hook)
|
451 |
+
|
452 |
+
def prepare_optimizer(self) -> None:
|
453 |
+
logger.info("Initializing optimizer and lr scheduler")
|
454 |
+
|
455 |
+
self.state.train_epochs = self.args.train_epochs
|
456 |
+
self.state.train_steps = self.args.train_steps
|
457 |
+
|
458 |
+
# Make sure the trainable params are in float32
|
459 |
+
if self.args.training_type == "lora":
|
460 |
+
cast_training_params([self.transformer], dtype=torch.float32)
|
461 |
+
|
462 |
+
self.state.learning_rate = self.args.lr
|
463 |
+
if self.args.scale_lr:
|
464 |
+
self.state.learning_rate = (
|
465 |
+
self.state.learning_rate
|
466 |
+
* self.args.gradient_accumulation_steps
|
467 |
+
* self.args.batch_size
|
468 |
+
* self.state.accelerator.num_processes
|
469 |
+
)
|
470 |
+
|
471 |
+
transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
|
472 |
+
transformer_parameters_with_lr = {
|
473 |
+
"params": transformer_trainable_parameters,
|
474 |
+
"lr": self.state.learning_rate,
|
475 |
+
}
|
476 |
+
params_to_optimize = [transformer_parameters_with_lr]
|
477 |
+
self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters)
|
478 |
+
|
479 |
+
use_deepspeed_opt = (
|
480 |
+
self.state.accelerator.state.deepspeed_plugin is not None
|
481 |
+
and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
|
482 |
+
)
|
483 |
+
optimizer = get_optimizer(
|
484 |
+
params_to_optimize=params_to_optimize,
|
485 |
+
optimizer_name=self.args.optimizer,
|
486 |
+
learning_rate=self.state.learning_rate,
|
487 |
+
beta1=self.args.beta1,
|
488 |
+
beta2=self.args.beta2,
|
489 |
+
beta3=self.args.beta3,
|
490 |
+
epsilon=self.args.epsilon,
|
491 |
+
weight_decay=self.args.weight_decay,
|
492 |
+
use_8bit=self.args.use_8bit_bnb,
|
493 |
+
use_deepspeed=use_deepspeed_opt,
|
494 |
+
)
|
495 |
+
|
496 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
|
497 |
+
if self.state.train_steps is None:
|
498 |
+
self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
|
499 |
+
self.state.overwrote_max_train_steps = True
|
500 |
+
|
501 |
+
use_deepspeed_lr_scheduler = (
|
502 |
+
self.state.accelerator.state.deepspeed_plugin is not None
|
503 |
+
and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
|
504 |
+
)
|
505 |
+
total_training_steps = self.state.train_steps * self.state.accelerator.num_processes
|
506 |
+
num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes
|
507 |
+
|
508 |
+
if use_deepspeed_lr_scheduler:
|
509 |
+
from accelerate.utils import DummyScheduler
|
510 |
+
|
511 |
+
lr_scheduler = DummyScheduler(
|
512 |
+
name=self.args.lr_scheduler,
|
513 |
+
optimizer=optimizer,
|
514 |
+
total_num_steps=total_training_steps,
|
515 |
+
num_warmup_steps=num_warmup_steps,
|
516 |
+
)
|
517 |
+
else:
|
518 |
+
lr_scheduler = get_scheduler(
|
519 |
+
name=self.args.lr_scheduler,
|
520 |
+
optimizer=optimizer,
|
521 |
+
num_warmup_steps=num_warmup_steps,
|
522 |
+
num_training_steps=total_training_steps,
|
523 |
+
num_cycles=self.args.lr_num_cycles,
|
524 |
+
power=self.args.lr_power,
|
525 |
+
)
|
526 |
+
|
527 |
+
self.optimizer = optimizer
|
528 |
+
self.lr_scheduler = lr_scheduler
|
529 |
+
|
530 |
+
def prepare_for_training(self) -> None:
|
531 |
+
self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
|
532 |
+
self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
|
533 |
+
)
|
534 |
+
|
535 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
536 |
+
num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
|
537 |
+
if self.state.overwrote_max_train_steps:
|
538 |
+
self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
|
539 |
+
# Afterwards we recalculate our number of training epochs
|
540 |
+
self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
|
541 |
+
self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
|
542 |
+
|
543 |
+
def prepare_trackers(self) -> None:
|
544 |
+
logger.info("Initializing trackers")
|
545 |
+
|
546 |
+
tracker_name = self.args.tracker_name or "finetrainers-experiment"
|
547 |
+
self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info())
|
548 |
+
|
549 |
+
def train(self) -> None:
|
550 |
+
logger.info("Starting training")
|
551 |
+
|
552 |
+
memory_statistics = get_memory_statistics()
|
553 |
+
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
|
554 |
+
|
555 |
+
if self.vae_config is None:
|
556 |
+
# If we've precomputed conditions and latents already, and are now re-using it, we will never load
|
557 |
+
# the VAE so self.vae_config will not be set. So, we need to load it here.
|
558 |
+
vae_cls = resolve_vae_cls_from_ckpt_path(
|
559 |
+
self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
|
560 |
+
)
|
561 |
+
vae_config = vae_cls.load_config(
|
562 |
+
self.args.pretrained_model_name_or_path,
|
563 |
+
subfolder="vae",
|
564 |
+
revision=self.args.revision,
|
565 |
+
cache_dir=self.args.cache_dir,
|
566 |
+
)
|
567 |
+
self.vae_config = FrozenDict(**vae_config)
|
568 |
+
|
569 |
+
# In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
|
570 |
+
# to able to load all diffusion components from a specific checkpoint folder during validation, we need to
|
571 |
+
# ensure the scheduler config is serialized as well.
|
572 |
+
if self.args.training_type == "full-finetune":
|
573 |
+
self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))
|
574 |
+
|
575 |
+
self.state.train_batch_size = (
|
576 |
+
self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
|
577 |
+
)
|
578 |
+
info = {
|
579 |
+
"trainable parameters": self.state.num_trainable_parameters,
|
580 |
+
"total samples": len(self.dataset),
|
581 |
+
"train epochs": self.state.train_epochs,
|
582 |
+
"train steps": self.state.train_steps,
|
583 |
+
"batches per device": self.args.batch_size,
|
584 |
+
"total batches observed per epoch": len(self.dataloader),
|
585 |
+
"train batch size": self.state.train_batch_size,
|
586 |
+
"gradient accumulation steps": self.args.gradient_accumulation_steps,
|
587 |
+
}
|
588 |
+
logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
|
589 |
+
|
590 |
+
global_step = 0
|
591 |
+
first_epoch = 0
|
592 |
+
initial_global_step = 0
|
593 |
+
|
594 |
+
# Potentially load in the weights and states from a previous save
|
595 |
+
(
|
596 |
+
resume_from_checkpoint_path,
|
597 |
+
initial_global_step,
|
598 |
+
global_step,
|
599 |
+
first_epoch,
|
600 |
+
) = get_latest_ckpt_path_to_resume_from(
|
601 |
+
resume_from_checkpoint=self.args.resume_from_checkpoint,
|
602 |
+
num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
|
603 |
+
output_dir=self.args.output_dir,
|
604 |
+
)
|
605 |
+
if resume_from_checkpoint_path:
|
606 |
+
self.state.accelerator.load_state(resume_from_checkpoint_path)
|
607 |
+
|
608 |
+
progress_bar = tqdm(
|
609 |
+
range(0, self.state.train_steps),
|
610 |
+
initial=initial_global_step,
|
611 |
+
desc="Training steps",
|
612 |
+
disable=not self.state.accelerator.is_local_main_process,
|
613 |
+
)
|
614 |
+
|
615 |
+
accelerator = self.state.accelerator
|
616 |
+
generator = torch.Generator(device=accelerator.device)
|
617 |
+
if self.args.seed is not None:
|
618 |
+
generator = generator.manual_seed(self.args.seed)
|
619 |
+
self.state.generator = generator
|
620 |
+
|
621 |
+
scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
|
622 |
+
scheduler_sigmas = (
|
623 |
+
scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
|
624 |
+
if scheduler_sigmas is not None
|
625 |
+
else None
|
626 |
+
)
|
627 |
+
scheduler_alphas = get_scheduler_alphas(self.scheduler)
|
628 |
+
scheduler_alphas = (
|
629 |
+
scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
|
630 |
+
if scheduler_alphas is not None
|
631 |
+
else None
|
632 |
+
)
|
633 |
+
|
634 |
+
for epoch in range(first_epoch, self.state.train_epochs):
|
635 |
+
logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")
|
636 |
+
|
637 |
+
self.transformer.train()
|
638 |
+
models_to_accumulate = [self.transformer]
|
639 |
+
epoch_loss = 0.0
|
640 |
+
num_loss_updates = 0
|
641 |
+
|
642 |
+
for step, batch in enumerate(self.dataloader):
|
643 |
+
logger.debug(f"Starting step {step + 1}")
|
644 |
+
logs = {}
|
645 |
+
|
646 |
+
with accelerator.accumulate(models_to_accumulate):
|
647 |
+
if not self.args.precompute_conditions:
|
648 |
+
videos = batch["videos"]
|
649 |
+
prompts = batch["prompts"]
|
650 |
+
batch_size = len(prompts)
|
651 |
+
|
652 |
+
if self.args.caption_dropout_technique == "empty":
|
653 |
+
if random.random() < self.args.caption_dropout_p:
|
654 |
+
prompts = [""] * batch_size
|
655 |
+
|
656 |
+
latent_conditions = self.model_config["prepare_latents"](
|
657 |
+
vae=self.vae,
|
658 |
+
image_or_video=videos,
|
659 |
+
patch_size=self.transformer_config.patch_size,
|
660 |
+
patch_size_t=self.transformer_config.patch_size_t,
|
661 |
+
device=accelerator.device,
|
662 |
+
dtype=self.args.transformer_dtype,
|
663 |
+
generator=self.state.generator,
|
664 |
+
)
|
665 |
+
text_conditions = self.model_config["prepare_conditions"](
|
666 |
+
tokenizer=self.tokenizer,
|
667 |
+
text_encoder=self.text_encoder,
|
668 |
+
tokenizer_2=self.tokenizer_2,
|
669 |
+
text_encoder_2=self.text_encoder_2,
|
670 |
+
prompt=prompts,
|
671 |
+
device=accelerator.device,
|
672 |
+
dtype=self.args.transformer_dtype,
|
673 |
+
)
|
674 |
+
else:
|
675 |
+
latent_conditions = batch["latent_conditions"]
|
676 |
+
text_conditions = batch["text_conditions"]
|
677 |
+
latent_conditions["latents"] = DiagonalGaussianDistribution(
|
678 |
+
latent_conditions["latents"]
|
679 |
+
).sample(self.state.generator)
|
680 |
+
|
681 |
+
# This method should only be called for precomputed latents.
|
682 |
+
# TODO(aryan): rename this in separate PR
|
683 |
+
latent_conditions = self.model_config["post_latent_preparation"](
|
684 |
+
vae_config=self.vae_config,
|
685 |
+
patch_size=self.transformer_config.patch_size,
|
686 |
+
patch_size_t=self.transformer_config.patch_size_t,
|
687 |
+
**latent_conditions,
|
688 |
+
)
|
689 |
+
align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
|
690 |
+
align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
|
691 |
+
batch_size = latent_conditions["latents"].shape[0]
|
692 |
+
|
693 |
+
latent_conditions = make_contiguous(latent_conditions)
|
694 |
+
text_conditions = make_contiguous(text_conditions)
|
695 |
+
|
696 |
+
if self.args.caption_dropout_technique == "zero":
|
697 |
+
if random.random() < self.args.caption_dropout_p:
|
698 |
+
text_conditions["prompt_embeds"].fill_(0)
|
699 |
+
text_conditions["prompt_attention_mask"].fill_(False)
|
700 |
+
|
701 |
+
# TODO(aryan): refactor later
|
702 |
+
if "pooled_prompt_embeds" in text_conditions:
|
703 |
+
text_conditions["pooled_prompt_embeds"].fill_(0)
|
704 |
+
|
705 |
+
sigmas = prepare_sigmas(
|
706 |
+
scheduler=self.scheduler,
|
707 |
+
sigmas=scheduler_sigmas,
|
708 |
+
batch_size=batch_size,
|
709 |
+
num_train_timesteps=self.scheduler.config.num_train_timesteps,
|
710 |
+
flow_weighting_scheme=self.args.flow_weighting_scheme,
|
711 |
+
flow_logit_mean=self.args.flow_logit_mean,
|
712 |
+
flow_logit_std=self.args.flow_logit_std,
|
713 |
+
flow_mode_scale=self.args.flow_mode_scale,
|
714 |
+
device=accelerator.device,
|
715 |
+
generator=self.state.generator,
|
716 |
+
)
|
717 |
+
timesteps = (sigmas * 1000.0).long()
|
718 |
+
|
719 |
+
noise = torch.randn(
|
720 |
+
latent_conditions["latents"].shape,
|
721 |
+
generator=self.state.generator,
|
722 |
+
device=accelerator.device,
|
723 |
+
dtype=self.args.transformer_dtype,
|
724 |
+
)
|
725 |
+
sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
|
726 |
+
|
727 |
+
# TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of
|
728 |
+
# scheduler and calculate the noisy latents accordingly. Look into this later.
|
729 |
+
if "calculate_noisy_latents" in self.model_config.keys():
|
730 |
+
noisy_latents = self.model_config["calculate_noisy_latents"](
|
731 |
+
scheduler=self.scheduler,
|
732 |
+
noise=noise,
|
733 |
+
latents=latent_conditions["latents"],
|
734 |
+
timesteps=timesteps,
|
735 |
+
)
|
736 |
+
else:
|
737 |
+
# Default to flow-matching noise addition
|
738 |
+
noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
|
739 |
+
noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)
|
740 |
+
|
741 |
+
latent_conditions.update({"noisy_latents": noisy_latents})
|
742 |
+
|
743 |
+
weights = prepare_loss_weights(
|
744 |
+
scheduler=self.scheduler,
|
745 |
+
alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
|
746 |
+
sigmas=sigmas,
|
747 |
+
flow_weighting_scheme=self.args.flow_weighting_scheme,
|
748 |
+
)
|
749 |
+
weights = expand_tensor_dims(weights, noise.ndim)
|
750 |
+
|
751 |
+
pred = self.model_config["forward_pass"](
|
752 |
+
transformer=self.transformer,
|
753 |
+
scheduler=self.scheduler,
|
754 |
+
timesteps=timesteps,
|
755 |
+
**latent_conditions,
|
756 |
+
**text_conditions,
|
757 |
+
)
|
758 |
+
target = prepare_target(
|
759 |
+
scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
|
760 |
+
)
|
761 |
+
|
762 |
+
loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
|
763 |
+
# Average loss across all but batch dimension
|
764 |
+
loss = loss.mean(list(range(1, loss.ndim)))
|
765 |
+
# Average loss across batch dimension
|
766 |
+
loss = loss.mean()
|
767 |
+
accelerator.backward(loss)
|
768 |
+
|
769 |
+
if accelerator.sync_gradients:
|
770 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
771 |
+
grad_norm = self.transformer.get_global_grad_norm()
|
772 |
+
# In some cases the grad norm may not return a float
|
773 |
+
if torch.is_tensor(grad_norm):
|
774 |
+
grad_norm = grad_norm.item()
|
775 |
+
else:
|
776 |
+
grad_norm = accelerator.clip_grad_norm_(
|
777 |
+
self.transformer.parameters(), self.args.max_grad_norm
|
778 |
+
)
|
779 |
+
if torch.is_tensor(grad_norm):
|
780 |
+
grad_norm = grad_norm.item()
|
781 |
+
|
782 |
+
logs["grad_norm"] = grad_norm
|
783 |
+
|
784 |
+
self.optimizer.step()
|
785 |
+
self.lr_scheduler.step()
|
786 |
+
self.optimizer.zero_grad()
|
787 |
+
|
788 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
789 |
+
if accelerator.sync_gradients:
|
790 |
+
progress_bar.update(1)
|
791 |
+
global_step += 1
|
792 |
+
|
793 |
+
# Checkpointing
|
794 |
+
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
|
795 |
+
if global_step % self.args.checkpointing_steps == 0:
|
796 |
+
save_path = get_intermediate_ckpt_path(
|
797 |
+
checkpointing_limit=self.args.checkpointing_limit,
|
798 |
+
step=global_step,
|
799 |
+
output_dir=self.args.output_dir,
|
800 |
+
)
|
801 |
+
accelerator.save_state(save_path)
|
802 |
+
|
803 |
+
# Maybe run validation
|
804 |
+
should_run_validation = (
|
805 |
+
self.args.validation_every_n_steps is not None
|
806 |
+
and global_step % self.args.validation_every_n_steps == 0
|
807 |
+
)
|
808 |
+
if should_run_validation:
|
809 |
+
self.validate(global_step)
|
810 |
+
|
811 |
+
loss_item = loss.detach().item()
|
812 |
+
epoch_loss += loss_item
|
813 |
+
num_loss_updates += 1
|
814 |
+
logs["step_loss"] = loss_item
|
815 |
+
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
|
816 |
+
progress_bar.set_postfix(logs)
|
817 |
+
accelerator.log(logs, step=global_step)
|
818 |
+
|
819 |
+
if global_step >= self.state.train_steps:
|
820 |
+
break
|
821 |
+
|
822 |
+
if num_loss_updates > 0:
|
823 |
+
epoch_loss /= num_loss_updates
|
824 |
+
accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
|
825 |
+
memory_statistics = get_memory_statistics()
|
826 |
+
logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
|
827 |
+
|
828 |
+
# Maybe run validation
|
829 |
+
should_run_validation = (
|
830 |
+
self.args.validation_every_n_epochs is not None
|
831 |
+
and (epoch + 1) % self.args.validation_every_n_epochs == 0
|
832 |
+
)
|
833 |
+
if should_run_validation:
|
834 |
+
self.validate(global_step)
|
835 |
+
|
836 |
+
accelerator.wait_for_everyone()
|
837 |
+
if accelerator.is_main_process:
|
838 |
+
transformer = unwrap_model(accelerator, self.transformer)
|
839 |
+
|
840 |
+
if self.args.training_type == "lora":
|
841 |
+
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
842 |
+
|
843 |
+
self.model_config["pipeline_cls"].save_lora_weights(
|
844 |
+
save_directory=self.args.output_dir,
|
845 |
+
transformer_lora_layers=transformer_lora_layers,
|
846 |
+
)
|
847 |
+
else:
|
848 |
+
transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))
|
849 |
+
accelerator.wait_for_everyone()
|
850 |
+
self.validate(step=global_step, final_validation=True)
|
851 |
+
|
852 |
+
if accelerator.is_main_process:
|
853 |
+
if self.args.push_to_hub:
|
854 |
+
upload_folder(
|
855 |
+
repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
|
856 |
+
)
|
857 |
+
|
858 |
+
self._delete_components()
|
859 |
+
memory_statistics = get_memory_statistics()
|
860 |
+
logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
|
861 |
+
|
862 |
+
accelerator.end_training()
|
863 |
+
|
864 |
+
def validate(self, step: int, final_validation: bool = False) -> None:
|
865 |
+
logger.info("Starting validation")
|
866 |
+
|
867 |
+
accelerator = self.state.accelerator
|
868 |
+
num_validation_samples = len(self.args.validation_prompts)
|
869 |
+
|
870 |
+
if num_validation_samples == 0:
|
871 |
+
logger.warning("No validation samples found. Skipping validation.")
|
872 |
+
if accelerator.is_main_process:
|
873 |
+
if self.args.push_to_hub:
|
874 |
+
save_model_card(
|
875 |
+
args=self.args,
|
876 |
+
repo_id=self.state.repo_id,
|
877 |
+
videos=None,
|
878 |
+
validation_prompts=None,
|
879 |
+
)
|
880 |
+
return
|
881 |
+
|
882 |
+
self.transformer.eval()
|
883 |
+
|
884 |
+
memory_statistics = get_memory_statistics()
|
885 |
+
logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
|
886 |
+
|
887 |
+
pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)
|
888 |
+
|
889 |
+
all_processes_artifacts = []
|
890 |
+
prompts_to_filenames = {}
|
891 |
+
for i in range(num_validation_samples):
|
892 |
+
# Skip current validation on all processes but one
|
893 |
+
if i % accelerator.num_processes != accelerator.process_index:
|
894 |
+
continue
|
895 |
+
|
896 |
+
prompt = self.args.validation_prompts[i]
|
897 |
+
image = self.args.validation_images[i]
|
898 |
+
video = self.args.validation_videos[i]
|
899 |
+
height = self.args.validation_heights[i]
|
900 |
+
width = self.args.validation_widths[i]
|
901 |
+
num_frames = self.args.validation_num_frames[i]
|
902 |
+
frame_rate = self.args.validation_frame_rate
|
903 |
+
if image is not None:
|
904 |
+
image = load_image(image)
|
905 |
+
if video is not None:
|
906 |
+
video = load_video(video)
|
907 |
+
|
908 |
+
logger.debug(
|
909 |
+
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
|
910 |
+
main_process_only=False,
|
911 |
+
)
|
912 |
+
validation_artifacts = self.model_config["validation"](
|
913 |
+
pipeline=pipeline,
|
914 |
+
prompt=prompt,
|
915 |
+
image=image,
|
916 |
+
video=video,
|
917 |
+
height=height,
|
918 |
+
width=width,
|
919 |
+
num_frames=num_frames,
|
920 |
+
frame_rate=frame_rate,
|
921 |
+
num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
|
922 |
+
generator=torch.Generator(device=accelerator.device).manual_seed(
|
923 |
+
self.args.seed if self.args.seed is not None else 0
|
924 |
+
),
|
925 |
+
# todo support passing `fps` for supported pipelines.
|
926 |
+
)
|
927 |
+
|
928 |
+
prompt_filename = string_to_filename(prompt)[:25]
|
929 |
+
artifacts = {
|
930 |
+
"image": {"type": "image", "value": image},
|
931 |
+
"video": {"type": "video", "value": video},
|
932 |
+
}
|
933 |
+
for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
|
934 |
+
if artifact_value:
|
935 |
+
artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
|
936 |
+
logger.debug(
|
937 |
+
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
|
938 |
+
main_process_only=False,
|
939 |
+
)
|
940 |
+
|
941 |
+
for index, (key, value) in enumerate(list(artifacts.items())):
|
942 |
+
artifact_type = value["type"]
|
943 |
+
artifact_value = value["value"]
|
944 |
+
if artifact_type not in ["image", "video"] or artifact_value is None:
|
945 |
+
continue
|
946 |
+
|
947 |
+
extension = "png" if artifact_type == "image" else "mp4"
|
948 |
+
filename = "validation-" if not final_validation else "final-"
|
949 |
+
filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
|
950 |
+
if accelerator.is_main_process and extension == "mp4":
|
951 |
+
prompts_to_filenames[prompt] = filename
|
952 |
+
filename = os.path.join(self.args.output_dir, filename)
|
953 |
+
|
954 |
+
if artifact_type == "image" and artifact_value:
|
955 |
+
logger.debug(f"Saving image to {filename}")
|
956 |
+
artifact_value.save(filename)
|
957 |
+
artifact_value = wandb.Image(filename)
|
958 |
+
elif artifact_type == "video" and artifact_value:
|
959 |
+
logger.debug(f"Saving video to {filename}")
|
960 |
+
# TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
|
961 |
+
export_to_video(artifact_value, filename, fps=frame_rate)
|
962 |
+
artifact_value = wandb.Video(filename, caption=prompt)
|
963 |
+
|
964 |
+
all_processes_artifacts.append(artifact_value)
|
965 |
+
|
966 |
+
all_artifacts = gather_object(all_processes_artifacts)
|
967 |
+
|
968 |
+
if accelerator.is_main_process:
|
969 |
+
tracker_key = "final" if final_validation else "validation"
|
970 |
+
for tracker in accelerator.trackers:
|
971 |
+
if tracker.name == "wandb":
|
972 |
+
artifact_log_dict = {}
|
973 |
+
|
974 |
+
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
|
975 |
+
if len(image_artifacts) > 0:
|
976 |
+
artifact_log_dict["images"] = image_artifacts
|
977 |
+
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
|
978 |
+
if len(video_artifacts) > 0:
|
979 |
+
artifact_log_dict["videos"] = video_artifacts
|
980 |
+
tracker.log({tracker_key: artifact_log_dict}, step=step)
|
981 |
+
|
982 |
+
if self.args.push_to_hub and final_validation:
|
983 |
+
video_filenames = list(prompts_to_filenames.values())
|
984 |
+
prompts = list(prompts_to_filenames.keys())
|
985 |
+
save_model_card(
|
986 |
+
args=self.args,
|
987 |
+
repo_id=self.state.repo_id,
|
988 |
+
videos=video_filenames,
|
989 |
+
validation_prompts=prompts,
|
990 |
+
)
|
991 |
+
|
992 |
+
# Remove all hooks that might have been added during pipeline initialization to the models
|
993 |
+
pipeline.remove_all_hooks()
|
994 |
+
del pipeline
|
995 |
+
|
996 |
+
accelerator.wait_for_everyone()
|
997 |
+
|
998 |
+
free_memory()
|
999 |
+
memory_statistics = get_memory_statistics()
|
1000 |
+
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
|
1001 |
+
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
1002 |
+
|
1003 |
+
if not final_validation:
|
1004 |
+
self.transformer.train()
|
1005 |
+
|
1006 |
+
def evaluate(self) -> None:
|
1007 |
+
raise NotImplementedError("Evaluation has not been implemented yet.")
|
1008 |
+
|
1009 |
+
def _init_distributed(self) -> None:
|
1010 |
+
logging_dir = Path(self.args.output_dir, self.args.logging_dir)
|
1011 |
+
project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
|
1012 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
1013 |
+
init_process_group_kwargs = InitProcessGroupKwargs(
|
1014 |
+
backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
|
1015 |
+
)
|
1016 |
+
report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
|
1017 |
+
|
1018 |
+
accelerator = Accelerator(
|
1019 |
+
project_config=project_config,
|
1020 |
+
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
|
1021 |
+
log_with=report_to,
|
1022 |
+
kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
# Disable AMP for MPS.
|
1026 |
+
if torch.backends.mps.is_available():
|
1027 |
+
accelerator.native_amp = False
|
1028 |
+
|
1029 |
+
self.state.accelerator = accelerator
|
1030 |
+
|
1031 |
+
if self.args.seed is not None:
|
1032 |
+
self.state.seed = self.args.seed
|
1033 |
+
set_seed(self.args.seed)
|
1034 |
+
|
1035 |
+
def _init_logging(self) -> None:
|
1036 |
+
logging.basicConfig(
|
1037 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
1038 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
1039 |
+
level=FINETRAINERS_LOG_LEVEL,
|
1040 |
+
)
|
1041 |
+
if self.state.accelerator.is_local_main_process:
|
1042 |
+
transformers.utils.logging.set_verbosity_warning()
|
1043 |
+
diffusers.utils.logging.set_verbosity_info()
|
1044 |
+
else:
|
1045 |
+
transformers.utils.logging.set_verbosity_error()
|
1046 |
+
diffusers.utils.logging.set_verbosity_error()
|
1047 |
+
|
1048 |
+
logger.info("Initialized FineTrainers")
|
1049 |
+
logger.info(self.state.accelerator.state, main_process_only=False)
|
1050 |
+
|
1051 |
+
def _init_directories_and_repositories(self) -> None:
|
1052 |
+
if self.state.accelerator.is_main_process:
|
1053 |
+
self.args.output_dir = Path(self.args.output_dir)
|
1054 |
+
self.args.output_dir.mkdir(parents=True, exist_ok=True)
|
1055 |
+
self.state.output_dir = Path(self.args.output_dir)
|
1056 |
+
|
1057 |
+
if self.args.push_to_hub:
|
1058 |
+
repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
|
1059 |
+
self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
|
1060 |
+
|
1061 |
+
def _init_config_options(self) -> None:
|
1062 |
+
# Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
1063 |
+
if self.args.allow_tf32 and torch.cuda.is_available():
|
1064 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
1065 |
+
|
1066 |
+
def _move_components_to_device(self):
|
1067 |
+
if self.text_encoder is not None:
|
1068 |
+
self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
|
1069 |
+
if self.text_encoder_2 is not None:
|
1070 |
+
self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
|
1071 |
+
if self.text_encoder_3 is not None:
|
1072 |
+
self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
|
1073 |
+
if self.transformer is not None:
|
1074 |
+
self.transformer = self.transformer.to(self.state.accelerator.device)
|
1075 |
+
if self.unet is not None:
|
1076 |
+
self.unet = self.unet.to(self.state.accelerator.device)
|
1077 |
+
if self.vae is not None:
|
1078 |
+
self.vae = self.vae.to(self.state.accelerator.device)
|
1079 |
+
|
1080 |
+
def _get_load_components_kwargs(self) -> Dict[str, Any]:
|
1081 |
+
load_component_kwargs = {
|
1082 |
+
"text_encoder_dtype": self.args.text_encoder_dtype,
|
1083 |
+
"text_encoder_2_dtype": self.args.text_encoder_2_dtype,
|
1084 |
+
"text_encoder_3_dtype": self.args.text_encoder_3_dtype,
|
1085 |
+
"transformer_dtype": self.args.transformer_dtype,
|
1086 |
+
"vae_dtype": self.args.vae_dtype,
|
1087 |
+
"shift": self.args.flow_shift,
|
1088 |
+
"revision": self.args.revision,
|
1089 |
+
"cache_dir": self.args.cache_dir,
|
1090 |
+
}
|
1091 |
+
if self.args.pretrained_model_name_or_path is not None:
|
1092 |
+
load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
|
1093 |
+
return load_component_kwargs
|
1094 |
+
|
1095 |
+
def _set_components(self, components: Dict[str, Any]) -> None:
|
1096 |
+
# Set models
|
1097 |
+
self.tokenizer = components.get("tokenizer", self.tokenizer)
|
1098 |
+
self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
|
1099 |
+
self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
|
1100 |
+
self.text_encoder = components.get("text_encoder", self.text_encoder)
|
1101 |
+
self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
|
1102 |
+
self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
|
1103 |
+
self.transformer = components.get("transformer", self.transformer)
|
1104 |
+
self.unet = components.get("unet", self.unet)
|
1105 |
+
self.vae = components.get("vae", self.vae)
|
1106 |
+
self.scheduler = components.get("scheduler", self.scheduler)
|
1107 |
+
|
1108 |
+
# Set configs
|
1109 |
+
self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
|
1110 |
+
self.vae_config = self.vae.config if self.vae is not None else self.vae_config
|
1111 |
+
|
1112 |
+
def _delete_components(self) -> None:
|
1113 |
+
self.tokenizer = None
|
1114 |
+
self.tokenizer_2 = None
|
1115 |
+
self.tokenizer_3 = None
|
1116 |
+
self.text_encoder = None
|
1117 |
+
self.text_encoder_2 = None
|
1118 |
+
self.text_encoder_3 = None
|
1119 |
+
self.transformer = None
|
1120 |
+
self.unet = None
|
1121 |
+
self.vae = None
|
1122 |
+
self.scheduler = None
|
1123 |
+
free_memory()
|
1124 |
+
torch.cuda.synchronize(self.state.accelerator.device)
|
1125 |
+
|
1126 |
+
def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
|
1127 |
+
accelerator = self.state.accelerator
|
1128 |
+
if not final_validation:
|
1129 |
+
pipeline = self.model_config["initialize_pipeline"](
|
1130 |
+
model_id=self.args.pretrained_model_name_or_path,
|
1131 |
+
tokenizer=self.tokenizer,
|
1132 |
+
text_encoder=self.text_encoder,
|
1133 |
+
tokenizer_2=self.tokenizer_2,
|
1134 |
+
text_encoder_2=self.text_encoder_2,
|
1135 |
+
transformer=unwrap_model(accelerator, self.transformer),
|
1136 |
+
vae=self.vae,
|
1137 |
+
device=accelerator.device,
|
1138 |
+
revision=self.args.revision,
|
1139 |
+
cache_dir=self.args.cache_dir,
|
1140 |
+
enable_slicing=self.args.enable_slicing,
|
1141 |
+
enable_tiling=self.args.enable_tiling,
|
1142 |
+
enable_model_cpu_offload=self.args.enable_model_cpu_offload,
|
1143 |
+
is_training=True,
|
1144 |
+
)
|
1145 |
+
else:
|
1146 |
+
self._delete_components()
|
1147 |
+
|
1148 |
+
# Load the transformer weights from the final checkpoint if performing full-finetune
|
1149 |
+
transformer = None
|
1150 |
+
if self.args.training_type == "full-finetune":
|
1151 |
+
transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]
|
1152 |
+
|
1153 |
+
pipeline = self.model_config["initialize_pipeline"](
|
1154 |
+
model_id=self.args.pretrained_model_name_or_path,
|
1155 |
+
transformer=transformer,
|
1156 |
+
device=accelerator.device,
|
1157 |
+
revision=self.args.revision,
|
1158 |
+
cache_dir=self.args.cache_dir,
|
1159 |
+
enable_slicing=self.args.enable_slicing,
|
1160 |
+
enable_tiling=self.args.enable_tiling,
|
1161 |
+
enable_model_cpu_offload=self.args.enable_model_cpu_offload,
|
1162 |
+
is_training=False,
|
1163 |
+
)
|
1164 |
+
|
1165 |
+
# Load the LoRA weights if performing LoRA finetuning
|
1166 |
+
if self.args.training_type == "lora":
|
1167 |
+
pipeline.load_lora_weights(self.args.output_dir)
|
1168 |
+
|
1169 |
+
return pipeline
|
1170 |
+
|
1171 |
+
def _disable_grad_for_components(self, components: List[torch.nn.Module]):
|
1172 |
+
for component in components:
|
1173 |
+
if component is not None:
|
1174 |
+
component.requires_grad_(False)
|
1175 |
+
|
1176 |
+
def _enable_grad_for_components(self, components: List[torch.nn.Module]):
|
1177 |
+
for component in components:
|
1178 |
+
if component is not None:
|
1179 |
+
component.requires_grad_(True)
|
1180 |
+
|
1181 |
+
def _get_training_info(self) -> dict:
|
1182 |
+
args = self.args.to_dict()
|
1183 |
+
|
1184 |
+
training_args = args.get("training_arguments", {})
|
1185 |
+
training_type = training_args.get("training_type", "")
|
1186 |
+
|
1187 |
+
# LoRA/non-LoRA stuff.
|
1188 |
+
if training_type == "full-finetune":
|
1189 |
+
filtered_training_args = {
|
1190 |
+
k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"}
|
1191 |
+
}
|
1192 |
+
else:
|
1193 |
+
filtered_training_args = training_args
|
1194 |
+
|
1195 |
+
# Diffusion/flow stuff.
|
1196 |
+
diffusion_args = args.get("diffusion_arguments", {})
|
1197 |
+
scheduler_name = self.scheduler.__class__.__name__
|
1198 |
+
if scheduler_name != "FlowMatchEulerDiscreteScheduler":
|
1199 |
+
filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
|
1200 |
+
else:
|
1201 |
+
filtered_diffusion_args = diffusion_args
|
1202 |
+
|
1203 |
+
# Rest of the stuff.
|
1204 |
+
updated_training_info = args.copy()
|
1205 |
+
updated_training_info["training_arguments"] = filtered_training_args
|
1206 |
+
updated_training_info["diffusion_arguments"] = filtered_diffusion_args
|
1207 |
+
return updated_training_info
|
finetrainers/utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .diffusion_utils import (
|
2 |
+
default_flow_shift,
|
3 |
+
get_scheduler_alphas,
|
4 |
+
get_scheduler_sigmas,
|
5 |
+
prepare_loss_weights,
|
6 |
+
prepare_sigmas,
|
7 |
+
prepare_target,
|
8 |
+
resolution_dependent_timestep_flow_shift,
|
9 |
+
)
|
10 |
+
from .file_utils import delete_files, find_files
|
11 |
+
from .memory_utils import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous
|
12 |
+
from .optimizer_utils import get_optimizer, gradient_norm, max_gradient
|
13 |
+
from .torch_utils import unwrap_model
|
finetrainers/utils/checkpointing.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
from accelerate.logging import get_logger
|
5 |
+
|
6 |
+
from ..constants import FINETRAINERS_LOG_LEVEL
|
7 |
+
from ..utils.file_utils import delete_files, find_files
|
8 |
+
|
9 |
+
|
10 |
+
logger = get_logger("finetrainers")
|
11 |
+
logger.setLevel(FINETRAINERS_LOG_LEVEL)
|
12 |
+
|
13 |
+
|
14 |
+
def get_latest_ckpt_path_to_resume_from(
|
15 |
+
resume_from_checkpoint: str, num_update_steps_per_epoch: int, output_dir: str
|
16 |
+
) -> Tuple[str, int, int, int]:
|
17 |
+
if not resume_from_checkpoint:
|
18 |
+
initial_global_step = 0
|
19 |
+
global_step = 0
|
20 |
+
first_epoch = 0
|
21 |
+
resume_from_checkpoint_path = None
|
22 |
+
else:
|
23 |
+
if resume_from_checkpoint != "latest":
|
24 |
+
path = os.path.basename(resume_from_checkpoint)
|
25 |
+
else:
|
26 |
+
# Get the most recent checkpoint
|
27 |
+
dirs = os.listdir(output_dir)
|
28 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
29 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
30 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
31 |
+
|
32 |
+
if path is None:
|
33 |
+
logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
|
34 |
+
resume_from_checkpoint = None
|
35 |
+
initial_global_step = 0
|
36 |
+
global_step = 0
|
37 |
+
first_epoch = 0
|
38 |
+
resume_from_checkpoint_path = None
|
39 |
+
else:
|
40 |
+
logger.info(f"Resuming from checkpoint {path}")
|
41 |
+
resume_from_checkpoint_path = os.path.join(output_dir, path)
|
42 |
+
global_step = int(path.split("-")[1])
|
43 |
+
|
44 |
+
initial_global_step = global_step
|
45 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
46 |
+
|
47 |
+
return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
|
48 |
+
|
49 |
+
|
50 |
+
def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
|
51 |
+
# before saving state, check if this save would set us over the `checkpointing_limit`
|
52 |
+
if checkpointing_limit is not None:
|
53 |
+
checkpoints = find_files(output_dir, prefix="checkpoint")
|
54 |
+
|
55 |
+
# before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
|
56 |
+
if len(checkpoints) >= checkpointing_limit:
|
57 |
+
num_to_remove = len(checkpoints) - checkpointing_limit + 1
|
58 |
+
checkpoints_to_remove = [os.path.join(output_dir, x) for x in checkpoints[0:num_to_remove]]
|
59 |
+
delete_files(checkpoints_to_remove)
|
60 |
+
|
61 |
+
logger.info(f"Checkpointing at step {step}")
|
62 |
+
save_path = os.path.join(output_dir, f"checkpoint-{step}")
|
63 |
+
logger.info(f"Saving state to {save_path}")
|
64 |
+
return save_path
|
finetrainers/utils/data_utils.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
from accelerate.logging import get_logger
|
5 |
+
|
6 |
+
from ..constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME
|
7 |
+
|
8 |
+
|
9 |
+
logger = get_logger("finetrainers")
|
10 |
+
|
11 |
+
|
12 |
+
def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool:
|
13 |
+
if isinstance(precomputation_dir, str):
|
14 |
+
precomputation_dir = Path(precomputation_dir)
|
15 |
+
conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
|
16 |
+
latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
|
17 |
+
if conditions_dir.exists() and latents_dir.exists():
|
18 |
+
num_files_conditions = len(list(conditions_dir.glob("*.pt")))
|
19 |
+
num_files_latents = len(list(latents_dir.glob("*.pt")))
|
20 |
+
if num_files_conditions != num_files_latents:
|
21 |
+
logger.warning(
|
22 |
+
f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})."
|
23 |
+
f"Cleaning up precomputed directories and re-running precomputation."
|
24 |
+
)
|
25 |
+
# clean up precomputed directories
|
26 |
+
for file in conditions_dir.glob("*.pt"):
|
27 |
+
file.unlink()
|
28 |
+
for file in latents_dir.glob("*.pt"):
|
29 |
+
file.unlink()
|
30 |
+
return True
|
31 |
+
if num_files_conditions > 0:
|
32 |
+
logger.info(f"Found {num_files_conditions} precomputed conditions and latents.")
|
33 |
+
return False
|
34 |
+
logger.info("Precomputed data not found. Running precomputation.")
|
35 |
+
return True
|
finetrainers/utils/diffusion_utils.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
|
6 |
+
from diffusers.training_utils import compute_loss_weighting_for_sd3
|
7 |
+
|
8 |
+
|
9 |
+
# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
|
10 |
+
def resolution_dependent_timestep_flow_shift(
|
11 |
+
latents: torch.Tensor,
|
12 |
+
sigmas: torch.Tensor,
|
13 |
+
base_image_seq_len: int = 256,
|
14 |
+
max_image_seq_len: int = 4096,
|
15 |
+
base_shift: float = 0.5,
|
16 |
+
max_shift: float = 1.15,
|
17 |
+
) -> torch.Tensor:
|
18 |
+
image_or_video_sequence_length = 0
|
19 |
+
if latents.ndim == 4:
|
20 |
+
image_or_video_sequence_length = latents.shape[2] * latents.shape[3]
|
21 |
+
elif latents.ndim == 5:
|
22 |
+
image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4]
|
23 |
+
else:
|
24 |
+
raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")
|
25 |
+
|
26 |
+
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
|
27 |
+
b = base_shift - m * base_image_seq_len
|
28 |
+
mu = m * image_or_video_sequence_length + b
|
29 |
+
sigmas = default_flow_shift(latents, sigmas, shift=mu)
|
30 |
+
return sigmas
|
31 |
+
|
32 |
+
|
33 |
+
def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
|
34 |
+
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
35 |
+
return sigmas
|
36 |
+
|
37 |
+
|
38 |
+
def compute_density_for_timestep_sampling(
|
39 |
+
weighting_scheme: str,
|
40 |
+
batch_size: int,
|
41 |
+
logit_mean: float = None,
|
42 |
+
logit_std: float = None,
|
43 |
+
mode_scale: float = None,
|
44 |
+
device: torch.device = torch.device("cpu"),
|
45 |
+
generator: Optional[torch.Generator] = None,
|
46 |
+
) -> torch.Tensor:
|
47 |
+
r"""
|
48 |
+
Compute the density for sampling the timesteps when doing SD3 training.
|
49 |
+
|
50 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
51 |
+
|
52 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
53 |
+
"""
|
54 |
+
if weighting_scheme == "logit_normal":
|
55 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
56 |
+
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
|
57 |
+
u = torch.nn.functional.sigmoid(u)
|
58 |
+
elif weighting_scheme == "mode":
|
59 |
+
u = torch.rand(size=(batch_size,), device=device, generator=generator)
|
60 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
61 |
+
else:
|
62 |
+
u = torch.rand(size=(batch_size,), device=device, generator=generator)
|
63 |
+
return u
|
64 |
+
|
65 |
+
|
66 |
+
def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
|
67 |
+
if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
|
68 |
+
return None
|
69 |
+
elif isinstance(scheduler, CogVideoXDDIMScheduler):
|
70 |
+
return scheduler.alphas_cumprod.clone()
|
71 |
+
else:
|
72 |
+
raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
|
73 |
+
|
74 |
+
|
75 |
+
def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
|
76 |
+
if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
|
77 |
+
return scheduler.sigmas.clone()
|
78 |
+
elif isinstance(scheduler, CogVideoXDDIMScheduler):
|
79 |
+
return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps)
|
80 |
+
else:
|
81 |
+
raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
|
82 |
+
|
83 |
+
|
84 |
+
def prepare_sigmas(
|
85 |
+
scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
|
86 |
+
sigmas: torch.Tensor,
|
87 |
+
batch_size: int,
|
88 |
+
num_train_timesteps: int,
|
89 |
+
flow_weighting_scheme: str = "none",
|
90 |
+
flow_logit_mean: float = 0.0,
|
91 |
+
flow_logit_std: float = 1.0,
|
92 |
+
flow_mode_scale: float = 1.29,
|
93 |
+
device: torch.device = torch.device("cpu"),
|
94 |
+
generator: Optional[torch.Generator] = None,
|
95 |
+
) -> torch.Tensor:
|
96 |
+
if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
|
97 |
+
weights = compute_density_for_timestep_sampling(
|
98 |
+
weighting_scheme=flow_weighting_scheme,
|
99 |
+
batch_size=batch_size,
|
100 |
+
logit_mean=flow_logit_mean,
|
101 |
+
logit_std=flow_logit_std,
|
102 |
+
mode_scale=flow_mode_scale,
|
103 |
+
device=device,
|
104 |
+
generator=generator,
|
105 |
+
)
|
106 |
+
indices = (weights * num_train_timesteps).long()
|
107 |
+
elif isinstance(scheduler, CogVideoXDDIMScheduler):
|
108 |
+
# TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes.
|
109 |
+
weights = torch.rand(size=(batch_size,), device=device, generator=generator)
|
110 |
+
indices = (weights * num_train_timesteps).long()
|
111 |
+
else:
|
112 |
+
raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
|
113 |
+
|
114 |
+
return sigmas[indices]
|
115 |
+
|
116 |
+
|
117 |
+
def prepare_loss_weights(
|
118 |
+
scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
|
119 |
+
alphas: Optional[torch.Tensor] = None,
|
120 |
+
sigmas: Optional[torch.Tensor] = None,
|
121 |
+
flow_weighting_scheme: str = "none",
|
122 |
+
) -> torch.Tensor:
|
123 |
+
if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
|
124 |
+
return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
|
125 |
+
elif isinstance(scheduler, CogVideoXDDIMScheduler):
|
126 |
+
# SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas).
|
127 |
+
# TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results.
|
128 |
+
return 1 / (1 - alphas)
|
129 |
+
else:
|
130 |
+
raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
|
131 |
+
|
132 |
+
|
133 |
+
def prepare_target(
|
134 |
+
scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
|
135 |
+
noise: torch.Tensor,
|
136 |
+
latents: torch.Tensor,
|
137 |
+
) -> torch.Tensor:
|
138 |
+
if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
|
139 |
+
target = noise - latents
|
140 |
+
elif isinstance(scheduler, CogVideoXDDIMScheduler):
|
141 |
+
target = latents
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
|
144 |
+
|
145 |
+
return target
|
finetrainers/utils/file_utils.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List, Union
|
6 |
+
|
7 |
+
|
8 |
+
logger = logging.getLogger("finetrainers")
|
9 |
+
logger.setLevel(os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO"))
|
10 |
+
|
11 |
+
|
12 |
+
def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]:
|
13 |
+
if not isinstance(dir, Path):
|
14 |
+
dir = Path(dir)
|
15 |
+
if not dir.exists():
|
16 |
+
return []
|
17 |
+
checkpoints = os.listdir(dir.as_posix())
|
18 |
+
checkpoints = [c for c in checkpoints if c.startswith(prefix)]
|
19 |
+
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
20 |
+
return checkpoints
|
21 |
+
|
22 |
+
|
23 |
+
def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
|
24 |
+
if not isinstance(dirs, list):
|
25 |
+
dirs = [dirs]
|
26 |
+
dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
|
27 |
+
logger.info(f"Deleting files: {dirs}")
|
28 |
+
for dir in dirs:
|
29 |
+
if not dir.exists():
|
30 |
+
continue
|
31 |
+
shutil.rmtree(dir, ignore_errors=True)
|
32 |
+
|
33 |
+
|
34 |
+
def string_to_filename(s: str) -> str:
|
35 |
+
return (
|
36 |
+
s.replace(" ", "-")
|
37 |
+
.replace("/", "-")
|
38 |
+
.replace(":", "-")
|
39 |
+
.replace(".", "-")
|
40 |
+
.replace(",", "-")
|
41 |
+
.replace(";", "-")
|
42 |
+
.replace("!", "-")
|
43 |
+
.replace("?", "-")
|
44 |
+
)
|
finetrainers/utils/hub_utils.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import wandb
|
6 |
+
from diffusers.utils import export_to_video
|
7 |
+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def save_model_card(
|
12 |
+
args,
|
13 |
+
repo_id: str,
|
14 |
+
videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]],
|
15 |
+
validation_prompts: List[str],
|
16 |
+
fps: int = 30,
|
17 |
+
) -> None:
|
18 |
+
widget_dict = []
|
19 |
+
output_dir = str(args.output_dir)
|
20 |
+
if videos is not None and len(videos) > 0:
|
21 |
+
for i, (video, validation_prompt) in enumerate(zip(videos, validation_prompts)):
|
22 |
+
if not isinstance(video, str):
|
23 |
+
export_to_video(video, os.path.join(output_dir, f"final_video_{i}.mp4"), fps=fps)
|
24 |
+
widget_dict.append(
|
25 |
+
{
|
26 |
+
"text": validation_prompt if validation_prompt else " ",
|
27 |
+
"output": {"url": video if isinstance(video, str) else f"final_video_{i}.mp4"},
|
28 |
+
}
|
29 |
+
)
|
30 |
+
|
31 |
+
training_type = "Full" if args.training_type == "full-finetune" else "LoRA"
|
32 |
+
model_description = f"""
|
33 |
+
# {training_type} Finetune
|
34 |
+
|
35 |
+
<Gallery />
|
36 |
+
|
37 |
+
## Model description
|
38 |
+
|
39 |
+
This is a {training_type.lower()} finetune of model: `{args.pretrained_model_name_or_path}`.
|
40 |
+
|
41 |
+
The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers).
|
42 |
+
|
43 |
+
`id_token` used: {args.id_token} (if it's not `None`, it should be used in the prompts.)
|
44 |
+
|
45 |
+
## Download model
|
46 |
+
|
47 |
+
[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
|
48 |
+
|
49 |
+
## Usage
|
50 |
+
|
51 |
+
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.
|
52 |
+
|
53 |
+
```py
|
54 |
+
TODO
|
55 |
+
```
|
56 |
+
|
57 |
+
For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
|
58 |
+
"""
|
59 |
+
if wandb.run and wandb.run.url:
|
60 |
+
model_description += f"""
|
61 |
+
Find out the wandb run URL and training configurations [here]({wandb.run.url}).
|
62 |
+
"""
|
63 |
+
|
64 |
+
model_card = load_or_create_model_card(
|
65 |
+
repo_id_or_path=repo_id,
|
66 |
+
from_training=True,
|
67 |
+
base_model=args.pretrained_model_name_or_path,
|
68 |
+
model_description=model_description,
|
69 |
+
widget=widget_dict,
|
70 |
+
)
|
71 |
+
tags = [
|
72 |
+
"text-to-video",
|
73 |
+
"diffusers-training",
|
74 |
+
"diffusers",
|
75 |
+
"finetrainers",
|
76 |
+
"template:sd-lora",
|
77 |
+
]
|
78 |
+
if training_type == "Full":
|
79 |
+
tags.append("full-finetune")
|
80 |
+
else:
|
81 |
+
tags.append("lora")
|
82 |
+
|
83 |
+
model_card = populate_model_card(model_card, tags=tags)
|
84 |
+
model_card.save(os.path.join(args.output_dir, "README.md"))
|
finetrainers/utils/memory_utils.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
from typing import Any, Dict, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from accelerate.logging import get_logger
|
6 |
+
|
7 |
+
|
8 |
+
logger = get_logger("finetrainers")
|
9 |
+
|
10 |
+
|
11 |
+
def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
|
12 |
+
memory_allocated = None
|
13 |
+
memory_reserved = None
|
14 |
+
max_memory_allocated = None
|
15 |
+
max_memory_reserved = None
|
16 |
+
|
17 |
+
if torch.cuda.is_available():
|
18 |
+
device = torch.cuda.current_device()
|
19 |
+
memory_allocated = torch.cuda.memory_allocated(device)
|
20 |
+
memory_reserved = torch.cuda.memory_reserved(device)
|
21 |
+
max_memory_allocated = torch.cuda.max_memory_allocated(device)
|
22 |
+
max_memory_reserved = torch.cuda.max_memory_reserved(device)
|
23 |
+
|
24 |
+
elif torch.backends.mps.is_available():
|
25 |
+
memory_allocated = torch.mps.current_allocated_memory()
|
26 |
+
|
27 |
+
else:
|
28 |
+
logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
|
29 |
+
|
30 |
+
return {
|
31 |
+
"memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision),
|
32 |
+
"memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision),
|
33 |
+
"max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision),
|
34 |
+
"max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision),
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
def bytes_to_gigabytes(x: int) -> float:
|
39 |
+
if x is not None:
|
40 |
+
return x / 1024**3
|
41 |
+
|
42 |
+
|
43 |
+
def free_memory() -> None:
|
44 |
+
if torch.cuda.is_available():
|
45 |
+
gc.collect()
|
46 |
+
torch.cuda.empty_cache()
|
47 |
+
torch.cuda.ipc_collect()
|
48 |
+
|
49 |
+
# TODO(aryan): handle non-cuda devices
|
50 |
+
|
51 |
+
|
52 |
+
def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
53 |
+
if isinstance(x, torch.Tensor):
|
54 |
+
return x.contiguous()
|
55 |
+
elif isinstance(x, dict):
|
56 |
+
return {k: make_contiguous(v) for k, v in x.items()}
|
57 |
+
else:
|
58 |
+
return x
|
finetrainers/utils/model_utils.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
|
7 |
+
|
8 |
+
def resolve_vae_cls_from_ckpt_path(ckpt_path, **kwargs):
|
9 |
+
ckpt_path = str(ckpt_path)
|
10 |
+
if os.path.exists(str(ckpt_path)) and os.path.isdir(ckpt_path):
|
11 |
+
index_path = os.path.join(ckpt_path, "model_index.json")
|
12 |
+
else:
|
13 |
+
revision = kwargs.get("revision", None)
|
14 |
+
cache_dir = kwargs.get("cache_dir", None)
|
15 |
+
index_path = hf_hub_download(
|
16 |
+
repo_id=ckpt_path, filename="model_index.json", revision=revision, cache_dir=cache_dir
|
17 |
+
)
|
18 |
+
|
19 |
+
with open(index_path, "r") as f:
|
20 |
+
model_index_dict = json.load(f)
|
21 |
+
assert "vae" in model_index_dict, "No VAE found in the modelx index dict."
|
22 |
+
|
23 |
+
vae_cls_config = model_index_dict["vae"]
|
24 |
+
library = importlib.import_module(vae_cls_config[0])
|
25 |
+
return getattr(library, vae_cls_config[1])
|
finetrainers/utils/optimizer_utils.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from accelerate.logging import get_logger
|
5 |
+
|
6 |
+
|
7 |
+
logger = get_logger("finetrainers")
|
8 |
+
|
9 |
+
|
10 |
+
def get_optimizer(
|
11 |
+
params_to_optimize,
|
12 |
+
optimizer_name: str = "adam",
|
13 |
+
learning_rate: float = 1e-3,
|
14 |
+
beta1: float = 0.9,
|
15 |
+
beta2: float = 0.95,
|
16 |
+
beta3: float = 0.98,
|
17 |
+
epsilon: float = 1e-8,
|
18 |
+
weight_decay: float = 1e-4,
|
19 |
+
prodigy_decouple: bool = False,
|
20 |
+
prodigy_use_bias_correction: bool = False,
|
21 |
+
prodigy_safeguard_warmup: bool = False,
|
22 |
+
use_8bit: bool = False,
|
23 |
+
use_4bit: bool = False,
|
24 |
+
use_torchao: bool = False,
|
25 |
+
use_deepspeed: bool = False,
|
26 |
+
use_cpu_offload_optimizer: bool = False,
|
27 |
+
offload_gradients: bool = False,
|
28 |
+
) -> torch.optim.Optimizer:
|
29 |
+
optimizer_name = optimizer_name.lower()
|
30 |
+
|
31 |
+
# Use DeepSpeed optimzer
|
32 |
+
if use_deepspeed:
|
33 |
+
from accelerate.utils import DummyOptim
|
34 |
+
|
35 |
+
return DummyOptim(
|
36 |
+
params_to_optimize,
|
37 |
+
lr=learning_rate,
|
38 |
+
betas=(beta1, beta2),
|
39 |
+
eps=epsilon,
|
40 |
+
weight_decay=weight_decay,
|
41 |
+
)
|
42 |
+
|
43 |
+
# TODO: consider moving the validation logic to `args.py` when we have torchao.
|
44 |
+
if use_8bit and use_4bit:
|
45 |
+
raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
|
46 |
+
|
47 |
+
if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
|
48 |
+
try:
|
49 |
+
import torchao # noqa
|
50 |
+
|
51 |
+
except ImportError:
|
52 |
+
raise ImportError(
|
53 |
+
"To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
|
54 |
+
)
|
55 |
+
|
56 |
+
if not use_torchao and use_4bit:
|
57 |
+
raise ValueError("4-bit Optimizers are only supported with torchao.")
|
58 |
+
|
59 |
+
# Optimizer creation
|
60 |
+
supported_optimizers = ["adam", "adamw", "prodigy", "came"]
|
61 |
+
if optimizer_name not in supported_optimizers:
|
62 |
+
logger.warning(
|
63 |
+
f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
|
64 |
+
)
|
65 |
+
optimizer_name = "adamw"
|
66 |
+
|
67 |
+
if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
|
68 |
+
raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
|
69 |
+
|
70 |
+
if use_8bit:
|
71 |
+
try:
|
72 |
+
import bitsandbytes as bnb
|
73 |
+
except ImportError:
|
74 |
+
raise ImportError(
|
75 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
76 |
+
)
|
77 |
+
|
78 |
+
if optimizer_name == "adamw":
|
79 |
+
if use_torchao:
|
80 |
+
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
|
81 |
+
|
82 |
+
optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
|
83 |
+
else:
|
84 |
+
optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
|
85 |
+
|
86 |
+
init_kwargs = {
|
87 |
+
"betas": (beta1, beta2),
|
88 |
+
"eps": epsilon,
|
89 |
+
"weight_decay": weight_decay,
|
90 |
+
}
|
91 |
+
|
92 |
+
elif optimizer_name == "adam":
|
93 |
+
if use_torchao:
|
94 |
+
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
|
95 |
+
|
96 |
+
optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
|
97 |
+
else:
|
98 |
+
optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
|
99 |
+
|
100 |
+
init_kwargs = {
|
101 |
+
"betas": (beta1, beta2),
|
102 |
+
"eps": epsilon,
|
103 |
+
"weight_decay": weight_decay,
|
104 |
+
}
|
105 |
+
|
106 |
+
elif optimizer_name == "prodigy":
|
107 |
+
try:
|
108 |
+
import prodigyopt
|
109 |
+
except ImportError:
|
110 |
+
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
|
111 |
+
|
112 |
+
optimizer_class = prodigyopt.Prodigy
|
113 |
+
|
114 |
+
if learning_rate <= 0.1:
|
115 |
+
logger.warning(
|
116 |
+
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
117 |
+
)
|
118 |
+
|
119 |
+
init_kwargs = {
|
120 |
+
"lr": learning_rate,
|
121 |
+
"betas": (beta1, beta2),
|
122 |
+
"beta3": beta3,
|
123 |
+
"eps": epsilon,
|
124 |
+
"weight_decay": weight_decay,
|
125 |
+
"decouple": prodigy_decouple,
|
126 |
+
"use_bias_correction": prodigy_use_bias_correction,
|
127 |
+
"safeguard_warmup": prodigy_safeguard_warmup,
|
128 |
+
}
|
129 |
+
|
130 |
+
elif optimizer_name == "came":
|
131 |
+
try:
|
132 |
+
import came_pytorch
|
133 |
+
except ImportError:
|
134 |
+
raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
|
135 |
+
|
136 |
+
optimizer_class = came_pytorch.CAME
|
137 |
+
|
138 |
+
init_kwargs = {
|
139 |
+
"lr": learning_rate,
|
140 |
+
"eps": (1e-30, 1e-16),
|
141 |
+
"betas": (beta1, beta2, beta3),
|
142 |
+
"weight_decay": weight_decay,
|
143 |
+
}
|
144 |
+
|
145 |
+
if use_cpu_offload_optimizer:
|
146 |
+
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
|
147 |
+
|
148 |
+
if "fused" in inspect.signature(optimizer_class.__init__).parameters:
|
149 |
+
init_kwargs.update({"fused": True})
|
150 |
+
|
151 |
+
optimizer = CPUOffloadOptimizer(
|
152 |
+
params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
optimizer = optimizer_class(params_to_optimize, **init_kwargs)
|
156 |
+
|
157 |
+
return optimizer
|
158 |
+
|
159 |
+
|
160 |
+
def gradient_norm(parameters):
|
161 |
+
norm = 0
|
162 |
+
for param in parameters:
|
163 |
+
if param.grad is None:
|
164 |
+
continue
|
165 |
+
local_norm = param.grad.detach().data.norm(2)
|
166 |
+
norm += local_norm.item() ** 2
|
167 |
+
norm = norm**0.5
|
168 |
+
return norm
|
169 |
+
|
170 |
+
|
171 |
+
def max_gradient(parameters):
|
172 |
+
max_grad_value = float("-inf")
|
173 |
+
for param in parameters:
|
174 |
+
if param.grad is None:
|
175 |
+
continue
|
176 |
+
local_max_grad = param.grad.detach().data.abs().max()
|
177 |
+
max_grad_value = max(max_grad_value, local_max_grad.item())
|
178 |
+
return max_grad_value
|
finetrainers/utils/torch_utils.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from accelerate import Accelerator
|
5 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
6 |
+
|
7 |
+
|
8 |
+
def unwrap_model(accelerator: Accelerator, model):
|
9 |
+
model = accelerator.unwrap_model(model)
|
10 |
+
model = model._orig_mod if is_compiled_module(model) else model
|
11 |
+
return model
|
12 |
+
|
13 |
+
|
14 |
+
def align_device_and_dtype(
|
15 |
+
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
16 |
+
device: Optional[torch.device] = None,
|
17 |
+
dtype: Optional[torch.dtype] = None,
|
18 |
+
):
|
19 |
+
if isinstance(x, torch.Tensor):
|
20 |
+
if device is not None:
|
21 |
+
x = x.to(device)
|
22 |
+
if dtype is not None:
|
23 |
+
x = x.to(dtype)
|
24 |
+
elif isinstance(x, dict):
|
25 |
+
if device is not None:
|
26 |
+
x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
|
27 |
+
if dtype is not None:
|
28 |
+
x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
|
29 |
+
return x
|
30 |
+
|
31 |
+
|
32 |
+
def expand_tensor_dims(tensor, ndim):
|
33 |
+
while len(tensor.shape) < ndim:
|
34 |
+
tensor = tensor.unsqueeze(-1)
|
35 |
+
return tensor
|
finetrainers_utils.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from pathlib import Path
|
3 |
+
import logging
|
4 |
+
import shutil
|
5 |
+
from typing import Any, Optional, Dict, List, Union, Tuple
|
6 |
+
from config import STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
|
7 |
+
from utils import extract_scene_info, make_archive, is_image_file, is_video_file
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
def prepare_finetrainers_dataset() -> Tuple[Path, Path]:
|
12 |
+
"""make sure we have a Finetrainers-compatible dataset structure
|
13 |
+
|
14 |
+
Checks that we have:
|
15 |
+
training/
|
16 |
+
├── prompt.txt # All captions, one per line
|
17 |
+
├── videos.txt # All video paths, one per line
|
18 |
+
└── videos/ # Directory containing all mp4 files
|
19 |
+
├── 00000.mp4
|
20 |
+
├── 00001.mp4
|
21 |
+
└── ...
|
22 |
+
Returns:
|
23 |
+
Tuple of (videos_file_path, prompts_file_path)
|
24 |
+
"""
|
25 |
+
|
26 |
+
# Verifies the videos subdirectory
|
27 |
+
TRAINING_VIDEOS_PATH.mkdir(exist_ok=True)
|
28 |
+
|
29 |
+
# Clear existing training lists
|
30 |
+
for f in TRAINING_PATH.glob("*"):
|
31 |
+
if f.is_file():
|
32 |
+
if f.name in ["videos.txt", "prompts.txt"]:
|
33 |
+
f.unlink()
|
34 |
+
|
35 |
+
videos_file = TRAINING_PATH / "videos.txt"
|
36 |
+
prompts_file = TRAINING_PATH / "prompts.txt" # Note: Changed from prompt.txt to prompts.txt to match our config
|
37 |
+
|
38 |
+
media_files = []
|
39 |
+
captions = []
|
40 |
+
# Process all video files from the videos subdirectory
|
41 |
+
for idx, file in enumerate(sorted(TRAINING_VIDEOS_PATH.glob("*.mp4"))):
|
42 |
+
caption_file = file.with_suffix('.txt')
|
43 |
+
if caption_file.exists():
|
44 |
+
# Normalize caption to single line
|
45 |
+
caption = caption_file.read_text().strip()
|
46 |
+
caption = ' '.join(caption.split())
|
47 |
+
|
48 |
+
# Use relative path from training root
|
49 |
+
relative_path = f"videos/{file.name}"
|
50 |
+
media_files.append(relative_path)
|
51 |
+
captions.append(caption)
|
52 |
+
|
53 |
+
# Clean up the caption file since it's now in prompts.txt
|
54 |
+
# EDIT well you know what, let's keep it, otherwise running the function
|
55 |
+
# twice might cause some errors
|
56 |
+
# caption_file.unlink()
|
57 |
+
|
58 |
+
# Write files if we have content
|
59 |
+
if media_files and captions:
|
60 |
+
videos_file.write_text('\n'.join(media_files))
|
61 |
+
prompts_file.write_text('\n'.join(captions))
|
62 |
+
|
63 |
+
else:
|
64 |
+
raise ValueError("No valid video/caption pairs found in training directory")
|
65 |
+
# Verify file contents
|
66 |
+
with open(videos_file) as vf:
|
67 |
+
video_lines = [l.strip() for l in vf.readlines() if l.strip()]
|
68 |
+
with open(prompts_file) as pf:
|
69 |
+
prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]
|
70 |
+
|
71 |
+
if len(video_lines) != len(prompt_lines):
|
72 |
+
raise ValueError(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
|
73 |
+
|
74 |
+
return videos_file, prompts_file
|
75 |
+
|
76 |
+
def copy_files_to_training_dir(prompt_prefix: str) -> int:
|
77 |
+
"""Just copy files over, with no destruction"""
|
78 |
+
|
79 |
+
gr.Info("Copying assets to the training dataset..")
|
80 |
+
|
81 |
+
# Find files needing captions
|
82 |
+
video_files = list(STAGING_PATH.glob("*.mp4"))
|
83 |
+
image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
|
84 |
+
all_files = video_files + image_files
|
85 |
+
|
86 |
+
nb_copied_pairs = 0
|
87 |
+
|
88 |
+
for file_path in all_files:
|
89 |
+
|
90 |
+
caption = ""
|
91 |
+
file_caption_path = file_path.with_suffix('.txt')
|
92 |
+
if file_caption_path.exists():
|
93 |
+
logger.debug(f"Found caption file: {file_caption_path}")
|
94 |
+
caption = file_caption_path.read_text()
|
95 |
+
|
96 |
+
# Get parent caption if this is a clip
|
97 |
+
parent_caption = ""
|
98 |
+
if "___" in file_path.stem:
|
99 |
+
parent_name, _ = extract_scene_info(file_path.stem)
|
100 |
+
#print(f"parent_name is {parent_name}")
|
101 |
+
parent_caption_path = STAGING_PATH / f"{parent_name}.txt"
|
102 |
+
if parent_caption_path.exists():
|
103 |
+
logger.debug(f"Found parent caption file: {parent_caption_path}")
|
104 |
+
parent_caption = parent_caption_path.read_text().strip()
|
105 |
+
|
106 |
+
target_file_path = TRAINING_VIDEOS_PATH / file_path.name
|
107 |
+
|
108 |
+
target_caption_path = target_file_path.with_suffix('.txt')
|
109 |
+
|
110 |
+
if parent_caption and not caption.endswith(parent_caption):
|
111 |
+
caption = f"{caption}\n{parent_caption}"
|
112 |
+
|
113 |
+
if prompt_prefix and not caption.startswith(prompt_prefix):
|
114 |
+
caption = f"{prompt_prefix}{caption}"
|
115 |
+
|
116 |
+
# make sure we only copy over VALID pairs
|
117 |
+
if caption:
|
118 |
+
target_caption_path.write_text(caption)
|
119 |
+
shutil.copy2(file_path, target_file_path)
|
120 |
+
nb_copied_pairs += 1
|
121 |
+
|
122 |
+
prepare_finetrainers_dataset()
|
123 |
+
|
124 |
+
gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")
|
125 |
+
|
126 |
+
return nb_copied_pairs
|
image_preprocessing.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from pathlib import Path
|
4 |
+
from PIL import Image
|
5 |
+
import pillow_avif
|
6 |
+
import logging
|
7 |
+
from config import NORMALIZE_IMAGES_TO, JPEG_QUALITY
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
def normalize_image(input_path: Path, output_path: Path) -> bool:
|
12 |
+
"""Convert image to normalized format (PNG or JPEG) and optionally remove black bars
|
13 |
+
|
14 |
+
Args:
|
15 |
+
input_path: Source image path
|
16 |
+
output_path: Target path
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
bool: True if successful, False otherwise
|
20 |
+
"""
|
21 |
+
try:
|
22 |
+
# Open image with PIL
|
23 |
+
with Image.open(input_path) as img:
|
24 |
+
# Convert to RGB if needed
|
25 |
+
if img.mode in ('RGBA', 'LA'):
|
26 |
+
background = Image.new('RGB', img.size, (255, 255, 255))
|
27 |
+
if img.mode == 'RGBA':
|
28 |
+
background.paste(img, mask=img.split()[3])
|
29 |
+
else:
|
30 |
+
background.paste(img, mask=img.split()[1])
|
31 |
+
img = background
|
32 |
+
elif img.mode != 'RGB':
|
33 |
+
img = img.convert('RGB')
|
34 |
+
|
35 |
+
# Convert to numpy for black bar detection
|
36 |
+
img_np = np.array(img)
|
37 |
+
|
38 |
+
# Detect black bars
|
39 |
+
top, bottom, left, right = detect_black_bars(img_np)
|
40 |
+
|
41 |
+
# Crop if black bars detected
|
42 |
+
if any([top > 0, bottom < img_np.shape[0] - 1,
|
43 |
+
left > 0, right < img_np.shape[1] - 1]):
|
44 |
+
img = img.crop((left, top, right, bottom))
|
45 |
+
|
46 |
+
# Save as configured format
|
47 |
+
if NORMALIZE_IMAGES_TO == 'png':
|
48 |
+
img.save(output_path, 'PNG', optimize=True)
|
49 |
+
else: # jpg
|
50 |
+
img.save(output_path, 'JPEG', quality=JPEG_QUALITY, optimize=True)
|
51 |
+
return True
|
52 |
+
|
53 |
+
except Exception as e:
|
54 |
+
logger.error(f"Error converting image {input_path}: {str(e)}")
|
55 |
+
return False
|
56 |
+
|
57 |
+
def detect_black_bars(img: np.ndarray) -> tuple[int, int, int, int]:
|
58 |
+
"""Detect black bars in image
|
59 |
+
|
60 |
+
Args:
|
61 |
+
img: numpy array of image (HxWxC)
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Tuple of (top, bottom, left, right) crop coordinates
|
65 |
+
"""
|
66 |
+
# Convert to grayscale if needed
|
67 |
+
if len(img.shape) == 3:
|
68 |
+
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
69 |
+
else:
|
70 |
+
gray = img
|
71 |
+
|
72 |
+
# Threshold to detect black regions
|
73 |
+
threshold = 20
|
74 |
+
black_mask = gray < threshold
|
75 |
+
|
76 |
+
# Find black bars by analyzing row/column means
|
77 |
+
row_means = np.mean(black_mask, axis=1)
|
78 |
+
col_means = np.mean(black_mask, axis=0)
|
79 |
+
|
80 |
+
# Detect edges where black bars end (95% threshold)
|
81 |
+
black_threshold = 0.95
|
82 |
+
|
83 |
+
# Find top and bottom crops
|
84 |
+
top = 0
|
85 |
+
bottom = img.shape[0]
|
86 |
+
|
87 |
+
for i, mean in enumerate(row_means):
|
88 |
+
if mean > black_threshold:
|
89 |
+
top = i + 1
|
90 |
+
else:
|
91 |
+
break
|
92 |
+
|
93 |
+
for i, mean in enumerate(reversed(row_means)):
|
94 |
+
if mean > black_threshold:
|
95 |
+
bottom = img.shape[0] - i - 1
|
96 |
+
else:
|
97 |
+
break
|
98 |
+
|
99 |
+
# Find left and right crops
|
100 |
+
left = 0
|
101 |
+
right = img.shape[1]
|
102 |
+
|
103 |
+
for i, mean in enumerate(col_means):
|
104 |
+
if mean > black_threshold:
|
105 |
+
left = i + 1
|
106 |
+
else:
|
107 |
+
break
|
108 |
+
|
109 |
+
for i, mean in enumerate(reversed(col_means)):
|
110 |
+
if mean > black_threshold:
|
111 |
+
right = img.shape[1] - i - 1
|
112 |
+
else:
|
113 |
+
break
|
114 |
+
|
115 |
+
return top, bottom, left, right
|
116 |
+
|
import_service.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import zipfile
|
4 |
+
import tempfile
|
5 |
+
import gradio as gr
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List, Dict, Optional, Tuple
|
8 |
+
from pytubefix import YouTube
|
9 |
+
import logging
|
10 |
+
from utils import is_image_file, is_video_file, add_prefix_to_caption
|
11 |
+
from image_preprocessing import normalize_image
|
12 |
+
|
13 |
+
from config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, TRAINING_PATH, DEFAULT_PROMPT_PREFIX
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
class ImportService:
|
18 |
+
def process_uploaded_files(self, file_paths: List[str]) -> str:
|
19 |
+
"""Process uploaded file (ZIP, MP4, or image)
|
20 |
+
|
21 |
+
Args:
|
22 |
+
file_paths: File paths to the ploaded files from Gradio
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
Status message string
|
26 |
+
"""
|
27 |
+
for file_path in file_paths:
|
28 |
+
file_path = Path(file_path)
|
29 |
+
try:
|
30 |
+
original_name = file_path.name
|
31 |
+
print("original_name = ", original_name)
|
32 |
+
|
33 |
+
# Determine file type from name
|
34 |
+
file_ext = file_path.suffix.lower()
|
35 |
+
|
36 |
+
if file_ext == '.zip':
|
37 |
+
return self.process_zip_file(file_path)
|
38 |
+
elif file_ext == '.mp4' or file_ext == '.webm':
|
39 |
+
return self.process_mp4_file(file_path, original_name)
|
40 |
+
elif is_image_file(file_path):
|
41 |
+
return self.process_image_file(file_path, original_name)
|
42 |
+
else:
|
43 |
+
raise gr.Error(f"Unsupported file type: {file_ext}")
|
44 |
+
|
45 |
+
except Exception as e:
|
46 |
+
raise gr.Error(f"Error processing file: {str(e)}")
|
47 |
+
|
48 |
+
def process_image_file(self, file_path: Path, original_name: str) -> str:
|
49 |
+
"""Process a single image file
|
50 |
+
|
51 |
+
Args:
|
52 |
+
file_path: Path to the image
|
53 |
+
original_name: Original filename
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
Status message string
|
57 |
+
"""
|
58 |
+
try:
|
59 |
+
# Create a unique filename with configured extension
|
60 |
+
stem = Path(original_name).stem
|
61 |
+
target_path = STAGING_PATH / f"{stem}.{NORMALIZE_IMAGES_TO}"
|
62 |
+
|
63 |
+
# If file already exists, add number suffix
|
64 |
+
counter = 1
|
65 |
+
while target_path.exists():
|
66 |
+
target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
|
67 |
+
counter += 1
|
68 |
+
|
69 |
+
# Convert to normalized format and remove black bars
|
70 |
+
success = normalize_image(file_path, target_path)
|
71 |
+
|
72 |
+
if not success:
|
73 |
+
raise gr.Error(f"Failed to process image: {original_name}")
|
74 |
+
|
75 |
+
# Handle caption
|
76 |
+
src_caption_path = file_path.with_suffix('.txt')
|
77 |
+
if src_caption_path.exists():
|
78 |
+
caption = src_caption_path.read_text()
|
79 |
+
caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
|
80 |
+
target_path.with_suffix('.txt').write_text(caption)
|
81 |
+
|
82 |
+
logger.info(f"Successfully stored image: {target_path.name}")
|
83 |
+
gr.Info(f"Successfully stored image: {target_path.name}")
|
84 |
+
return f"Successfully stored image: {target_path.name}"
|
85 |
+
|
86 |
+
except Exception as e:
|
87 |
+
raise gr.Error(f"Error processing image file: {str(e)}")
|
88 |
+
|
89 |
+
def process_zip_file(self, file_path: Path) -> str:
|
90 |
+
"""Process uploaded ZIP file containing media files
|
91 |
+
|
92 |
+
Args:
|
93 |
+
file_path: Path to the uploaded ZIP file
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
Status message string
|
97 |
+
"""
|
98 |
+
try:
|
99 |
+
video_count = 0
|
100 |
+
image_count = 0
|
101 |
+
|
102 |
+
# Create temporary directory
|
103 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
104 |
+
# Extract ZIP
|
105 |
+
extract_dir = Path(temp_dir) / "extracted"
|
106 |
+
extract_dir.mkdir()
|
107 |
+
with zipfile.ZipFile(file_path, 'r') as zip_ref:
|
108 |
+
zip_ref.extractall(extract_dir)
|
109 |
+
|
110 |
+
# Process each file
|
111 |
+
for root, _, files in os.walk(extract_dir):
|
112 |
+
for file in files:
|
113 |
+
if file.startswith('._'): # Skip Mac metadata
|
114 |
+
continue
|
115 |
+
|
116 |
+
file_path = Path(root) / file
|
117 |
+
|
118 |
+
try:
|
119 |
+
if is_video_file(file_path):
|
120 |
+
# Copy video to videos_to_split
|
121 |
+
target_path = VIDEOS_TO_SPLIT_PATH / file_path.name
|
122 |
+
counter = 1
|
123 |
+
while target_path.exists():
|
124 |
+
target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}"
|
125 |
+
counter += 1
|
126 |
+
shutil.copy2(file_path, target_path)
|
127 |
+
video_count += 1
|
128 |
+
|
129 |
+
elif is_image_file(file_path):
|
130 |
+
# Convert image and save to staging
|
131 |
+
target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
|
132 |
+
counter = 1
|
133 |
+
while target_path.exists():
|
134 |
+
target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
|
135 |
+
counter += 1
|
136 |
+
if normalize_image(file_path, target_path):
|
137 |
+
image_count += 1
|
138 |
+
|
139 |
+
# Copy associated caption file if it exists
|
140 |
+
txt_path = file_path.with_suffix('.txt')
|
141 |
+
if txt_path.exists():
|
142 |
+
if is_video_file(file_path):
|
143 |
+
shutil.copy2(txt_path, target_path.with_suffix('.txt'))
|
144 |
+
elif is_image_file(file_path):
|
145 |
+
shutil.copy2(txt_path, target_path.with_suffix('.txt'))
|
146 |
+
|
147 |
+
except Exception as e:
|
148 |
+
logger.error(f"Error processing {file_path.name}: {str(e)}")
|
149 |
+
continue
|
150 |
+
|
151 |
+
# Generate status message
|
152 |
+
parts = []
|
153 |
+
if video_count > 0:
|
154 |
+
parts.append(f"{video_count} videos")
|
155 |
+
if image_count > 0:
|
156 |
+
parts.append(f"{image_count} images")
|
157 |
+
|
158 |
+
if not parts:
|
159 |
+
return "No supported media files found in ZIP"
|
160 |
+
|
161 |
+
status = f"Successfully stored {' and '.join(parts)}"
|
162 |
+
gr.Info(status)
|
163 |
+
return status
|
164 |
+
|
165 |
+
except Exception as e:
|
166 |
+
raise gr.Error(f"Error processing ZIP: {str(e)}")
|
167 |
+
|
168 |
+
def process_mp4_file(self, file_path: Path, original_name: str) -> str:
|
169 |
+
"""Process a single video file
|
170 |
+
|
171 |
+
Args:
|
172 |
+
file_path: Path to the file
|
173 |
+
original_name: Original filename
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
Status message string
|
177 |
+
"""
|
178 |
+
try:
|
179 |
+
# Create a unique filename
|
180 |
+
target_path = VIDEOS_TO_SPLIT_PATH / original_name
|
181 |
+
|
182 |
+
# If file already exists, add number suffix
|
183 |
+
counter = 1
|
184 |
+
while target_path.exists():
|
185 |
+
stem = Path(original_name).stem
|
186 |
+
target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4"
|
187 |
+
counter += 1
|
188 |
+
|
189 |
+
# Copy the file to the target location
|
190 |
+
shutil.copy2(file_path, target_path)
|
191 |
+
|
192 |
+
gr.Info(f"Successfully stored video: {target_path.name}")
|
193 |
+
return f"Successfully stored video: {target_path.name}"
|
194 |
+
|
195 |
+
except Exception as e:
|
196 |
+
raise gr.Error(f"Error processing video file: {str(e)}")
|
197 |
+
|
198 |
+
def download_youtube_video(self, url: str, progress=None) -> Dict:
|
199 |
+
"""Download a video from YouTube
|
200 |
+
|
201 |
+
Args:
|
202 |
+
url: YouTube video URL
|
203 |
+
progress: Optional Gradio progress indicator
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
Dict with status message and error (if any)
|
207 |
+
"""
|
208 |
+
try:
|
209 |
+
# Extract video ID and create YouTube object
|
210 |
+
yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining:
|
211 |
+
progress((1 - bytes_remaining / stream.filesize), desc="Downloading...")
|
212 |
+
if progress else None)
|
213 |
+
|
214 |
+
video_id = yt.video_id
|
215 |
+
output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"
|
216 |
+
|
217 |
+
# Download highest quality progressive MP4
|
218 |
+
if progress:
|
219 |
+
print("Getting video streams...")
|
220 |
+
progress(0, desc="Getting video streams...")
|
221 |
+
video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
|
222 |
+
|
223 |
+
if not video:
|
224 |
+
print("Could not find a compatible video format")
|
225 |
+
gr.Error("Could not find a compatible video format")
|
226 |
+
return "Could not find a compatible video format"
|
227 |
+
|
228 |
+
# Download the video
|
229 |
+
if progress:
|
230 |
+
print("Starting YouTube video download...")
|
231 |
+
progress(0, desc="Starting download...")
|
232 |
+
|
233 |
+
video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")
|
234 |
+
|
235 |
+
# Update UI
|
236 |
+
if progress:
|
237 |
+
print("YouTube video download complete!")
|
238 |
+
gr.Info("YouTube video download complete!")
|
239 |
+
progress(1, desc="Download complete!")
|
240 |
+
return f"Successfully downloaded video: {yt.title}"
|
241 |
+
|
242 |
+
except Exception as e:
|
243 |
+
print(e)
|
244 |
+
gr.Error(f"Error downloading video: {str(e)}")
|
245 |
+
return f"Error downloading video: {str(e)}"
|
requirements.txt
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy>=1.26.4
|
2 |
+
|
3 |
+
# to quote a-r-r-o-w/finetrainers:
|
4 |
+
# It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
|
5 |
+
torch==2.5.1
|
6 |
+
torchvision==0.20.1
|
7 |
+
torchao==0.6.1
|
8 |
+
|
9 |
+
huggingface_hub
|
10 |
+
hf_transfer>=0.1.8
|
11 |
+
diffusers>=0.30.3
|
12 |
+
transformers>=4.45.2
|
13 |
+
|
14 |
+
accelerate
|
15 |
+
bitsandbytes
|
16 |
+
peft>=0.12.0
|
17 |
+
eva-decord==0.6.1
|
18 |
+
wandb
|
19 |
+
pandas
|
20 |
+
sentencepiece>=0.2.0
|
21 |
+
imageio-ffmpeg>=0.5.1
|
22 |
+
|
23 |
+
flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
24 |
+
|
25 |
+
# for youtube video download
|
26 |
+
pytube
|
27 |
+
pytubefix
|
28 |
+
|
29 |
+
# for scene splitting
|
30 |
+
scenedetect[opencv]
|
31 |
+
|
32 |
+
# for llava video / captionning
|
33 |
+
pillow
|
34 |
+
pillow-avif-plugin
|
35 |
+
polars
|
36 |
+
einops
|
37 |
+
open_clip_torch
|
38 |
+
av==14.1.0
|
39 |
+
git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
40 |
+
|
41 |
+
# for our frontend
|
42 |
+
gradio==5.15.0
|
43 |
+
gradio_toggle
|
requirements_without_flash_attention.txt
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy>=1.26.4
|
2 |
+
|
3 |
+
# to quote a-r-r-o-w/finetrainers:
|
4 |
+
# It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
|
5 |
+
torch==2.5.1
|
6 |
+
torchvision==0.20.1
|
7 |
+
torchao==0.6.1
|
8 |
+
|
9 |
+
|
10 |
+
huggingface_hub
|
11 |
+
hf_transfer>=0.1.8
|
12 |
+
diffusers>=0.30.3
|
13 |
+
transformers>=4.45.2
|
14 |
+
|
15 |
+
accelerate
|
16 |
+
bitsandbytes
|
17 |
+
peft>=0.12.0
|
18 |
+
eva-decord==0.6.1
|
19 |
+
wandb
|
20 |
+
pandas
|
21 |
+
sentencepiece>=0.2.0
|
22 |
+
imageio-ffmpeg>=0.5.1
|
23 |
+
|
24 |
+
# for youtube video download
|
25 |
+
pytube
|
26 |
+
pytubefix
|
27 |
+
|
28 |
+
# for scene splitting
|
29 |
+
scenedetect[opencv]
|
30 |
+
|
31 |
+
# for llava video / captionning
|
32 |
+
pillow
|
33 |
+
pillow-avif-plugin
|
34 |
+
polars
|
35 |
+
einops
|
36 |
+
open_clip_torch
|
37 |
+
av==14.1.0
|
38 |
+
git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
39 |
+
|
40 |
+
# for our frontend
|
41 |
+
gradio==5.15.0
|
42 |
+
gradio_toggle
|
run.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
source .venv/bin/activate
|
4 |
+
|
5 |
+
USE_MOCK_CAPTIONING_MODEL=True python app.py
|
setup.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
python -m venv .venv
|
4 |
+
|
5 |
+
source .venv/bin/activate
|
6 |
+
|
7 |
+
python -m pip install -r requirements.txt
|
setup_no_captions.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
python -m venv .venv
|
4 |
+
|
5 |
+
source .venv/bin/activate
|
6 |
+
|
7 |
+
python -m pip install -r requirements_without_flash_attention.txt
|
8 |
+
|
9 |
+
# if you require flash attention, please install it manually for your operating system
|
10 |
+
|
11 |
+
# you can try this:
|
12 |
+
# python -m pip install wheel setuptools flash-attn --no-build-isolation --no-cache-dir
|