jbilcke-hf (HF Staff) committed
Commit 7c52128 · 1 Parent(s): c6546ad

making our code more robust

docs/finetrainers/documentation_models_cogvideox.md ADDED
@@ -0,0 +1,50 @@
+ # CogVideoX
+
+ ## Training
+
+ For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
+
+ Examples available:
+ - [PIKA crush effect](../../examples/training/sft/cogvideox/crush_smol_lora/)
+
+ To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
+
+ ```bash
+ chmod +x ./examples/training/sft/cogvideox/crush_smol_lora/train.sh
+ ./examples/training/sft/cogvideox/crush_smol_lora/train.sh
+ ```
+
+ On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
+
+ ## Supported checkpoints
+
+ CogVideoX has multiple checkpoints, as listed [here](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce). The following checkpoints were tested with `finetrainers` and are known to work:
+
+ * [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b)
+ * [THUDM/CogVideoX-5B](https://huggingface.co/THUDM/CogVideoX-5B)
+ * [THUDM/CogVideoX1.5-5B](https://huggingface.co/THUDM/CogVideoX1.5-5B)
+
+ ## Inference
+
+ Assuming your LoRA is saved and pushed to the HF Hub under the name `my-awesome-name/my-awesome-lora`, you can use the finetuned model for inference as follows:
+
+ ```diff
+ import torch
+ from diffusers import CogVideoXPipeline
+ from diffusers.utils import export_to_video
+
+ pipe = CogVideoXPipeline.from_pretrained(
+     "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ + pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")
+ + pipe.set_adapters(["cogvideox-lora"], [0.75])
+
+ video = pipe("<my-awesome-prompt>").frames[0]
+ export_to_video(video, "output.mp4")
+ ```
+
+ You can refer to the following guides to learn more about the model pipeline and performing LoRA inference in `diffusers`:
+
+ * [CogVideoX in Diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox)
+ * [Load LoRAs for inference](https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference)
+ * [Merge LoRAs](https://huggingface.co/docs/diffusers/main/en/using-diffusers/merge_loras)
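The "Merge LoRAs" guide linked above covers fusing adapters into the base weights. A minimal sketch of that flow (an illustration only, not part of this commit; it reuses the same hypothetical `my-awesome-name/my-awesome-lora` adapter as the snippet above) could look like:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="cogvideox-lora")

# Fuse the adapter into the base weights at the desired strength, then drop the
# standalone LoRA layers so later calls run on the merged weights only.
pipe.fuse_lora(adapter_names=["cogvideox-lora"], lora_scale=0.75)
pipe.unload_lora_weights()

video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4")
```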
docs/finetrainers/documentation_models_hunyuan_video.md CHANGED
@@ -4,151 +4,17 @@
4
 
5
  For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
 
7
- ```bash
8
- #!/bin/bash
9
-
10
- export WANDB_MODE="offline"
11
- export NCCL_P2P_DISABLE=1
12
- export TORCH_NCCL_ENABLE_MONITORING=0
13
- export FINETRAINERS_LOG_LEVEL=DEBUG
14
-
15
- GPU_IDS="0,1"
16
-
17
- DATA_ROOT="/path/to/dataset"
18
- CAPTION_COLUMN="prompts.txt"
19
- VIDEO_COLUMN="videos.txt"
20
- OUTPUT_DIR="/path/to/models/hunyuan-video/"
21
-
22
- ID_TOKEN="afkx"
23
-
24
- # Model arguments
25
- model_cmd="--model_name hunyuan_video \
26
- --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"
27
-
28
- # Dataset arguments
29
- dataset_cmd="--data_root $DATA_ROOT \
30
- --video_column $VIDEO_COLUMN \
31
- --caption_column $CAPTION_COLUMN \
32
- --id_token $ID_TOKEN \
33
- --video_resolution_buckets 17x512x768 49x512x768 61x512x768 \
34
- --caption_dropout_p 0.05"
35
-
36
- # Dataloader arguments
37
- dataloader_cmd="--dataloader_num_workers 0"
38
-
39
- # Diffusion arguments
40
- diffusion_cmd=""
41
-
42
- # Training arguments
43
- training_cmd="--training_type lora \
44
- --seed 42 \
45
- --batch_size 1 \
46
- --train_steps 500 \
47
- --rank 128 \
48
- --lora_alpha 128 \
49
- --target_modules to_q to_k to_v to_out.0 \
50
- --gradient_accumulation_steps 1 \
51
- --gradient_checkpointing \
52
- --checkpointing_steps 500 \
53
- --checkpointing_limit 2 \
54
- --enable_slicing \
55
- --enable_tiling"
56
-
57
- # Optimizer arguments
58
- optimizer_cmd="--optimizer adamw \
59
- --lr 2e-5 \
60
- --lr_scheduler constant_with_warmup \
61
- --lr_warmup_steps 100 \
62
- --lr_num_cycles 1 \
63
- --beta1 0.9 \
64
- --beta2 0.95 \
65
- --weight_decay 1e-4 \
66
- --epsilon 1e-8 \
67
- --max_grad_norm 1.0"
68
-
69
- # Miscellaneous arguments
70
- miscellaneous_cmd="--tracker_name finetrainers-hunyuan-video \
71
- --output_dir $OUTPUT_DIR \
72
- --nccl_timeout 1800 \
73
- --report_to wandb"
74
-
75
- cmd="accelerate launch --config_file accelerate_configs/uncompiled_8.yaml --gpu_ids $GPU_IDS train.py \
76
- $model_cmd \
77
- $dataset_cmd \
78
- $dataloader_cmd \
79
- $diffusion_cmd \
80
- $training_cmd \
81
- $optimizer_cmd \
82
- $miscellaneous_cmd"
83
-
84
- echo "Running command: $cmd"
85
- eval $cmd
86
- echo -ne "-------------------- Finished executing script --------------------\n\n"
87
- ```
88
 
89
- ## Memory Usage
90
 
91
- ### LoRA
92
-
93
- > [!NOTE]
94
- >
95
- > The below measurements are done in `torch.bfloat16` precision. Memory usage can further be reduce by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
96
-
97
- LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:
98
-
99
- ```
100
- Training configuration: {
101
- "trainable parameters": 163577856,
102
- "total samples": 69,
103
- "train epochs": 1,
104
- "train steps": 10,
105
- "batches per device": 1,
106
- "total batches observed per epoch": 69,
107
- "train batch size": 1,
108
- "gradient accumulation steps": 1
109
- }
110
- ```
111
-
112
- | stage | memory_allocated | max_memory_reserved |
113
- |:-----------------------:|:----------------:|:-------------------:|
114
- | before training start | 38.889 | 39.020 |
115
- | before validation start | 39.747 | 56.266 |
116
- | after validation end | 39.748 | 58.385 |
117
- | after epoch 1 | 39.748 | 40.910 |
118
- | after training end | 25.288 | 40.910 |
119
-
120
- Note: requires about `59` GB of VRAM when validation is performed.
121
-
122
- LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **with precomputation**:
123
-
124
- ```
125
- Training configuration: {
126
- "trainable parameters": 163577856,
127
- "total samples": 1,
128
- "train epochs": 10,
129
- "train steps": 10,
130
- "batches per device": 1,
131
- "total batches observed per epoch": 1,
132
- "train batch size": 1,
133
- "gradient accumulation steps": 1
134
- }
135
  ```
136
 
137
- | stage | memory_allocated | max_memory_reserved |
138
- |:-----------------------------:|:----------------:|:-------------------:|
139
- | after precomputing conditions | 14.232 | 14.461 |
140
- | after precomputing latents | 14.717 | 17.244 |
141
- | before training start | 24.195 | 26.039 |
142
- | after epoch 1 | 24.83 | 42.387 |
143
- | before validation start | 24.842 | 42.387 |
144
- | after validation end | 39.558 | 46.947 |
145
- | after training end | 24.842 | 41.039 |
146
-
147
- Note: requires about `47` GB of VRAM with validation. If validation is not performed, the memory usage is reduced to about `42` GB.
148
-
149
- ### Full finetuning
150
-
151
- Current, full finetuning is not supported for HunyuanVideo. It goes out of memory (OOM) for `49x512x768` resolutions.
+ Examples available:
+ - [PIKA Dissolve effect](../../examples/training/sft/hunyuan_video/modal_labs_dissolve/)
+
+ To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
+
+ ```bash
+ chmod +x ./examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh
+ ./examples/training/sft/hunyuan_video/modal_labs_dissolve/train.sh
+ ```
+
+ On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
 
 ## Inference
 
docs/finetrainers/documentation_models_ltx_video.md CHANGED
@@ -4,171 +4,17 @@
4
 
5
  For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
6
 
7
- ```bash
8
- #!/bin/bash
9
- export WANDB_MODE="offline"
10
- export NCCL_P2P_DISABLE=1
11
- export TORCH_NCCL_ENABLE_MONITORING=0
12
- export FINETRAINERS_LOG_LEVEL=DEBUG
13
-
14
- GPU_IDS="0,1"
15
-
16
- DATA_ROOT="/path/to/dataset"
17
- CAPTION_COLUMN="prompts.txt"
18
- VIDEO_COLUMN="videos.txt"
19
- OUTPUT_DIR="/path/to/models/ltx-video/"
20
-
21
- ID_TOKEN="BW_STYLE"
22
-
23
- # Model arguments
24
- model_cmd="--model_name ltx_video \
25
- --pretrained_model_name_or_path Lightricks/LTX-Video"
26
-
27
- # Dataset arguments
28
- dataset_cmd="--data_root $DATA_ROOT \
29
- --video_column $VIDEO_COLUMN \
30
- --caption_column $CAPTION_COLUMN \
31
- --id_token $ID_TOKEN \
32
- --video_resolution_buckets 49x512x768 \
33
- --caption_dropout_p 0.05"
34
-
35
- # Dataloader arguments
36
- dataloader_cmd="--dataloader_num_workers 0"
37
-
38
- # Diffusion arguments
39
- diffusion_cmd="--flow_weighting_scheme logit_normal"
40
-
41
- # Training arguments
42
- training_cmd="--training_type lora \
43
- --seed 42 \
44
- --batch_size 1 \
45
- --train_steps 3000 \
46
- --rank 128 \
47
- --lora_alpha 128 \
48
- --target_modules to_q to_k to_v to_out.0 \
49
- --gradient_accumulation_steps 4 \
50
- --gradient_checkpointing \
51
- --checkpointing_steps 500 \
52
- --checkpointing_limit 2 \
53
- --enable_slicing \
54
- --enable_tiling"
55
-
56
- # Optimizer arguments
57
- optimizer_cmd="--optimizer adamw \
58
- --lr 3e-5 \
59
- --lr_scheduler constant_with_warmup \
60
- --lr_warmup_steps 100 \
61
- --lr_num_cycles 1 \
62
- --beta1 0.9 \
63
- --beta2 0.95 \
64
- --weight_decay 1e-4 \
65
- --epsilon 1e-8 \
66
- --max_grad_norm 1.0"
67
-
68
- # Miscellaneous arguments
69
- miscellaneous_cmd="--tracker_name finetrainers-ltxv \
70
- --output_dir $OUTPUT_DIR \
71
- --nccl_timeout 1800 \
72
- --report_to wandb"
73
-
74
- cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
75
- $model_cmd \
76
- $dataset_cmd \
77
- $dataloader_cmd \
78
- $diffusion_cmd \
79
- $training_cmd \
80
- $optimizer_cmd \
81
- $miscellaneous_cmd"
82
-
83
- echo "Running command: $cmd"
84
- eval $cmd
85
- echo -ne "-------------------- Finished executing script --------------------\n\n"
86
- ```
87
-
88
- ## Memory Usage
89
-
90
- ### LoRA
91
 
92
- > [!NOTE]
93
- >
94
- > The below measurements are done in `torch.bfloat16` precision. Memory usage can further be reduce by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
95
 
96
- LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
97
-
98
- ```
99
- Training configuration: {
100
- "trainable parameters": 117440512,
101
- "total samples": 69,
102
- "train epochs": 1,
103
- "train steps": 10,
104
- "batches per device": 1,
105
- "total batches observed per epoch": 69,
106
- "train batch size": 1,
107
- "gradient accumulation steps": 1
108
- }
109
- ```
110
-
111
- | stage | memory_allocated | max_memory_reserved |
112
- |:-----------------------:|:----------------:|:-------------------:|
113
- | before training start | 13.486 | 13.879 |
114
- | before validation start | 14.146 | 17.623 |
115
- | after validation end | 14.146 | 17.623 |
116
- | after epoch 1 | 14.146 | 17.623 |
117
- | after training end | 4.461 | 17.623 |
118
-
119
- Note: requires about `18` GB of VRAM without precomputation.
120
-
121
- LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:
122
-
123
- ```
124
- Training configuration: {
125
- "trainable parameters": 117440512,
126
- "total samples": 1,
127
- "train epochs": 10,
128
- "train steps": 10,
129
- "batches per device": 1,
130
- "total batches observed per epoch": 1,
131
- "train batch size": 1,
132
- "gradient accumulation steps": 1
133
- }
134
- ```
135
-
136
- | stage | memory_allocated | max_memory_reserved |
137
- |:-----------------------------:|:----------------:|:-------------------:|
138
- | after precomputing conditions | 8.88 | 8.920 |
139
- | after precomputing latents | 9.684 | 11.613 |
140
- | before training start | 3.809 | 10.010 |
141
- | after epoch 1 | 4.26 | 10.916 |
142
- | before validation start | 4.26 | 10.916 |
143
- | after validation end | 13.924 | 17.262 |
144
- | after training end | 4.26 | 14.314 |
145
-
146
- Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB.
147
-
148
- ### Full Finetuning
149
-
150
- ```
151
- Training configuration: {
152
- "trainable parameters": 1923385472,
153
- "total samples": 1,
154
- "train epochs": 10,
155
- "train steps": 10,
156
- "batches per device": 1,
157
- "total batches observed per epoch": 1,
158
- "train batch size": 1,
159
- "gradient accumulation steps": 1
160
- }
161
  ```
162
 
163
- | stage | memory_allocated | max_memory_reserved |
164
- |:-----------------------------:|:----------------:|:-------------------:|
165
- | after precomputing conditions | 8.89 | 8.937 |
166
- | after precomputing latents | 9.701 | 11.615 |
167
- | before training start | 3.583 | 4.025 |
168
- | after epoch 1 | 10.769 | 20.357 |
169
- | before validation start | 10.769 | 20.357 |
170
- | after validation end | 10.769 | 28.332 |
171
- | after training end | 10.769 | 12.904 |
+ Examples available:
+ - [PIKA crush effect](../../examples/training/sft/ltx_video/crush_smol_lora/)
+
+ To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
+
+ ```bash
+ chmod +x ./examples/training/sft/ltx_video/crush_smol_lora/train.sh
+ ./examples/training/sft/ltx_video/crush_smol_lora/train.sh
+ ```
+
+ On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
 
 ## Inference
 
docs/finetrainers/documentation_models_wan.md CHANGED
@@ -4,11 +4,18 @@
 
 For LoRA training, specify `--training_type lora`. For full finetuning, specify `--training_type full-finetune`.
 
- See [this](../../examples/training/sft/wan/crush_smol_lora/) example training script for training Wan with Pika Effects Crush.
-
- ## Memory Usage
-
- TODO
+ Examples available:
+ - [PIKA crush effect](../../examples/training/sft/wan/crush_smol_lora/)
+ - [3DGS dissolve](../../examples/training/sft/wan/3dgs_dissolve/)
+
+ To run an example, run the following from the root directory of the repository (assuming you have installed the requirements and are using Linux/WSL):
+
+ ```bash
+ chmod +x ./examples/training/sft/wan/crush_smol_lora/train.sh
+ ./examples/training/sft/wan/crush_smol_lora/train.sh
+ ```
+
+ On Windows, you will have to modify the script to a compatible format to run it. [TODO(aryan): improve instructions for Windows]
 
 ## Inference
 
vms/config.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  from dataclasses import dataclass, field
3
  from typing import Dict, Any, Optional, List, Tuple
4
  from pathlib import Path
 
 
5
 
6
  def parse_bool_env(env_value: Optional[str]) -> bool:
7
  """Parse environment variable string to boolean
@@ -71,7 +73,16 @@ TRAINING_TYPES = {
71
 
72
  DEFAULT_SEED = 42
73
 
74
- DEFAULT_NB_TRAINING_STEPS = 1000
 
75
 
76
  DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200
77
 
@@ -87,6 +98,23 @@ DEFAULT_BATCH_SIZE = 1
87
 
88
  DEFAULT_LEARNING_RATE = 3e-5
89
 
90
  # it is best to use resolutions that are powers of 8
91
  # The resolution should be divisible by 32
92
  # so we cannot use 1080, 540 etc as they are not divisible by 32
@@ -183,7 +211,10 @@ TRAINING_PRESETS = {
183
  "learning_rate": 2e-5,
184
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
185
  "training_buckets": SMALL_TRAINING_BUCKETS,
186
- "flow_weighting_scheme": "none"
 
 
 
187
  },
188
  "LTX-Video (normal)": {
189
  "model_type": "ltx_video",
@@ -195,7 +226,10 @@ TRAINING_PRESETS = {
195
  "learning_rate": DEFAULT_LEARNING_RATE,
196
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
197
  "training_buckets": SMALL_TRAINING_BUCKETS,
198
- "flow_weighting_scheme": "logit_normal"
 
 
 
199
  },
200
  "LTX-Video (16:9, HQ)": {
201
  "model_type": "ltx_video",
@@ -207,7 +241,10 @@ TRAINING_PRESETS = {
207
  "learning_rate": DEFAULT_LEARNING_RATE,
208
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
209
  "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
210
- "flow_weighting_scheme": "logit_normal"
 
 
 
211
  },
212
  "LTX-Video (Full Finetune)": {
213
  "model_type": "ltx_video",
@@ -217,7 +254,10 @@ TRAINING_PRESETS = {
217
  "learning_rate": DEFAULT_LEARNING_RATE,
218
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
219
  "training_buckets": SMALL_TRAINING_BUCKETS,
220
- "flow_weighting_scheme": "logit_normal"
 
 
 
221
  },
222
  "Wan-2.1-T2V (normal)": {
223
  "model_type": "wan",
@@ -229,7 +269,10 @@ TRAINING_PRESETS = {
229
  "learning_rate": 5e-5,
230
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
231
  "training_buckets": SMALL_TRAINING_BUCKETS,
232
- "flow_weighting_scheme": "logit_normal"
 
 
 
233
  },
234
  "Wan-2.1-T2V (HQ)": {
235
  "model_type": "wan",
@@ -241,7 +284,10 @@ TRAINING_PRESETS = {
241
  "learning_rate": DEFAULT_LEARNING_RATE,
242
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
243
  "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
244
- "flow_weighting_scheme": "logit_normal"
 
 
 
245
  }
246
  }
247
 
@@ -287,7 +333,7 @@ class TrainingConfig:
287
  seed: int = DEFAULT_SEED
288
  mixed_precision: str = "bf16"
289
  batch_size: int = 1
290
- train_step: int = DEFAULT_NB_TRAINING_STEPS
291
  lora_rank: int = DEFAULT_LORA_RANK
292
  lora_alpha: int = DEFAULT_LORA_ALPHA
293
  target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
@@ -301,10 +347,10 @@ class TrainingConfig:
301
 
302
  # Optimizer arguments
303
  optimizer: str = "adamw"
304
- lr: float = 3e-5
305
  scale_lr: bool = False
306
  lr_scheduler: str = "constant_with_warmup"
307
- lr_warmup_steps: int = 100
308
  lr_num_cycles: int = 1
309
  lr_power: float = 1.0
310
  beta1: float = 0.9
 
2
  from dataclasses import dataclass, field
3
  from typing import Dict, Any, Optional, List, Tuple
4
  from pathlib import Path
5
+ import torch
6
+ import math
7
 
8
  def parse_bool_env(env_value: Optional[str]) -> bool:
9
  """Parse environment variable string to boolean
 
73
 
74
  DEFAULT_SEED = 42
75
 
76
+ DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES = True
77
+
78
+ DEFAULT_DATASET_TYPE = "video"
79
+ DEFAULT_TRAINING_TYPE = "lora"
80
+
81
+ DEFAULT_RESHAPE_MODE = "bicubic"
82
+
83
+ DEFAULT_MIXED_PRECISION = "bf16"
84
+
85
+
86
 
87
  DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200
88
 
 
98
 
99
  DEFAULT_LEARNING_RATE = 3e-5
100
 
101
+ # GPU SETTINGS
102
+ DEFAULT_NUM_GPUS = 1
103
+ DEFAULT_MAX_GPUS = min(8, torch.cuda.device_count() if torch.cuda.is_available() else 1)
104
+ DEFAULT_PRECOMPUTATION_ITEMS = 512
105
+
106
+ DEFAULT_NB_TRAINING_STEPS = 1000
107
+
108
+ # For this value, it is recommended to use about 20 to 40% of the number of training steps
109
+ DEFAULT_NB_LR_WARMUP_STEPS = math.ceil(0.20 * DEFAULT_NB_TRAINING_STEPS) # 20% of training steps
110
+
111
+ # For validation
112
+ DEFAULT_VALIDATION_NB_STEPS = 50
113
+ DEFAULT_VALIDATION_HEIGHT = 512
114
+ DEFAULT_VALIDATION_WIDTH = 768
115
+ DEFAULT_VALIDATION_NB_FRAMES = 49
116
+ DEFAULT_VALIDATION_FRAMERATE = 8
117
+
118
  # it is best to use resolutions that are powers of 8
119
  # The resolution should be divisible by 32
120
  # so we cannot use 1080, 540 etc as they are not divisible by 32
 
211
  "learning_rate": 2e-5,
212
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
213
  "training_buckets": SMALL_TRAINING_BUCKETS,
214
+ "flow_weighting_scheme": "none",
215
+ "num_gpus": DEFAULT_NUM_GPUS,
216
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
217
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
218
  },
219
  "LTX-Video (normal)": {
220
  "model_type": "ltx_video",
 
226
  "learning_rate": DEFAULT_LEARNING_RATE,
227
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
228
  "training_buckets": SMALL_TRAINING_BUCKETS,
229
+ "flow_weighting_scheme": "none",
230
+ "num_gpus": DEFAULT_NUM_GPUS,
231
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
232
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
233
  },
234
  "LTX-Video (16:9, HQ)": {
235
  "model_type": "ltx_video",
 
241
  "learning_rate": DEFAULT_LEARNING_RATE,
242
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
243
  "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
244
+ "flow_weighting_scheme": "logit_normal",
245
+ "num_gpus": DEFAULT_NUM_GPUS,
246
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
247
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
248
  },
249
  "LTX-Video (Full Finetune)": {
250
  "model_type": "ltx_video",
 
254
  "learning_rate": DEFAULT_LEARNING_RATE,
255
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
256
  "training_buckets": SMALL_TRAINING_BUCKETS,
257
+ "flow_weighting_scheme": "logit_normal",
258
+ "num_gpus": DEFAULT_NUM_GPUS,
259
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
260
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
261
  },
262
  "Wan-2.1-T2V (normal)": {
263
  "model_type": "wan",
 
269
  "learning_rate": 5e-5,
270
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
271
  "training_buckets": SMALL_TRAINING_BUCKETS,
272
+ "flow_weighting_scheme": "logit_normal",
273
+ "num_gpus": DEFAULT_NUM_GPUS,
274
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
275
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
276
  },
277
  "Wan-2.1-T2V (HQ)": {
278
  "model_type": "wan",
 
284
  "learning_rate": DEFAULT_LEARNING_RATE,
285
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
286
  "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
287
+ "flow_weighting_scheme": "logit_normal",
288
+ "num_gpus": DEFAULT_NUM_GPUS,
289
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
290
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS,
291
  }
292
  }
293
 
 
333
  seed: int = DEFAULT_SEED
334
  mixed_precision: str = "bf16"
335
  batch_size: int = 1
336
+ train_steps: int = DEFAULT_NB_TRAINING_STEPS
337
  lora_rank: int = DEFAULT_LORA_RANK
338
  lora_alpha: int = DEFAULT_LORA_ALPHA
339
  target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
 
347
 
348
  # Optimizer arguments
349
  optimizer: str = "adamw"
350
+ lr: float = DEFAULT_LEARNING_RATE
351
  scale_lr: bool = False
352
  lr_scheduler: str = "constant_with_warmup"
353
+ lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS
354
  lr_num_cycles: int = 1
355
  lr_power: float = 1.0
356
  beta1: float = 0.9
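The new config defaults above derive the warmup length from the total step count (the comment recommends 20 to 40% of the training steps). As a small illustration of that rule — a sketch independent of the repository's code, not part of this commit — the same calculation applied to custom run lengths would be:

```python
import math

# Illustrative only: mirror the 20% rule used for DEFAULT_NB_LR_WARMUP_STEPS above.
def suggested_warmup_steps(train_steps: int, fraction: float = 0.20) -> int:
    # The recommended range is 20-40% of the total training steps.
    return math.ceil(fraction * train_steps)

print(suggested_warmup_steps(1000))  # 200, matching the default above
print(suggested_warmup_steps(3000))  # 600
```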
vms/services/trainer.py CHANGED
@@ -28,9 +28,26 @@ from ..config import (
28
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
29
  DEFAULT_LEARNING_RATE,
30
  DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
31
- DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 
32
  )
33
- from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
34
 
35
  logger = logging.getLogger(__name__)
36
 
@@ -107,18 +124,89 @@ class TrainingService:
107
 
108
 
109
  def save_ui_state(self, values: Dict[str, Any]) -> None:
110
- """Save current UI state to file"""
111
  ui_state_file = OUTPUT_PATH / "ui_state.json"
112
  try:
 
 
 
 
113
  with open(ui_state_file, 'w') as f:
114
- json.dump(values, f, indent=2)
115
- logger.debug(f"UI state saved: {values}")
116
  except Exception as e:
117
  logger.error(f"Error saving UI state: {str(e)}")
118
 
119
- # Additional fix for the load_ui_state method in trainer.py to clean up old values
 
120
  def load_ui_state(self) -> Dict[str, Any]:
121
- """Load saved UI state"""
122
  ui_state_file = OUTPUT_PATH / "ui_state.json"
123
  default_state = {
124
  "model_type": list(MODEL_TYPES.keys())[0],
@@ -129,7 +217,10 @@ class TrainingService:
129
  "batch_size": DEFAULT_BATCH_SIZE,
130
  "learning_rate": DEFAULT_LEARNING_RATE,
131
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
132
- "training_preset": list(TRAINING_PRESETS.keys())[0]
 
 
 
133
  }
134
 
135
  if not ui_state_file.exists():
@@ -149,7 +240,13 @@ class TrainingService:
149
  logger.warning("UI state file is empty or contains only whitespace, using default values")
150
  return default_state
151
 
152
- saved_state = json.loads(file_content)
 
153
 
154
  # Clean up model type if it contains " (LoRA)" suffix
155
  if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
@@ -158,17 +255,36 @@ class TrainingService:
158
 
159
  # Convert numeric values to appropriate types
160
  if "train_steps" in saved_state:
161
- saved_state["train_steps"] = int(saved_state["train_steps"])
 
 
 
 
 
162
  if "batch_size" in saved_state:
163
- saved_state["batch_size"] = int(saved_state["batch_size"])
 
 
 
 
 
164
  if "learning_rate" in saved_state:
165
- saved_state["learning_rate"] = float(saved_state["learning_rate"])
 
 
 
 
 
166
  if "save_iterations" in saved_state:
167
- saved_state["save_iterations"] = int(saved_state["save_iterations"])
 
 
 
 
168
 
169
  # Make sure we have all keys (in case structure changed)
170
  merged_state = default_state.copy()
171
- merged_state.update(saved_state)
172
 
173
  # Validate model_type is in available choices
174
  if merged_state["model_type"] not in MODEL_TYPES:
@@ -203,67 +319,80 @@ class TrainingService:
203
  merged_state["training_preset"] = default_state["training_preset"]
204
  logger.warning(f"Invalid training preset in saved state, using default")
205
 
 
 
 
 
 
 
 
 
 
 
206
  return merged_state
207
- except json.JSONDecodeError as e:
208
- logger.error(f"Error parsing UI state JSON: {str(e)}")
209
- return default_state
210
  except Exception as e:
211
  logger.error(f"Error loading UI state: {str(e)}")
 
 
212
  return default_state
213
 
214
  def ensure_valid_ui_state_file(self):
215
  """Ensure UI state file exists and is valid JSON"""
216
  ui_state_file = OUTPUT_PATH / "ui_state.json"
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  if not ui_state_file.exists():
219
- # Create a new file with default values
220
  logger.info("Creating new UI state file with default values")
221
- default_state = {
222
- "model_type": list(MODEL_TYPES.keys())[0],
223
- "training_type": list(TRAINING_TYPES.keys())[0],
224
- "lora_rank": DEFAULT_LORA_RANK_STR,
225
- "lora_alpha": DEFAULT_LORA_ALPHA_STR,
226
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
227
- "batch_size": DEFAULT_BATCH_SIZE,
228
- "learning_rate": DEFAULT_LEARNING_RATE,
229
- "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
230
- "training_preset": list(TRAINING_PRESETS.keys())[0]
231
- }
232
  self.save_ui_state(default_state)
233
  return
234
 
235
  # Check if file is valid JSON
236
  try:
 
 
 
 
 
 
 
237
  with open(ui_state_file, 'r') as f:
238
  file_content = f.read().strip()
239
  if not file_content:
240
- raise ValueError("Empty file")
241
- json.loads(file_content)
242
- logger.debug("UI state file validation successful")
 
 
 
 
 
 
 
 
 
 
243
  except Exception as e:
244
- logger.warning(f"Invalid UI state file detected: {str(e)}. Creating new one with defaults.")
245
- # Backup the invalid file
246
- backup_file = ui_state_file.with_suffix('.json.bak')
247
- try:
248
- shutil.copy2(ui_state_file, backup_file)
249
- logger.info(f"Backed up invalid UI state file to {backup_file}")
250
- except Exception as backup_error:
251
- logger.error(f"Failed to backup invalid UI state file: {str(backup_error)}")
252
-
253
- # Create a new file with default values
254
- default_state = {
255
- "model_type": list(MODEL_TYPES.keys())[0],
256
- "training_type": list(TRAINING_TYPES.keys())[0],
257
- "lora_rank": DEFAULT_LORA_RANK_STR,
258
- "lora_alpha": DEFAULT_LORA_ALPHA_STR,
259
- "train_steps": DEFAULT_NB_TRAINING_STEPS,
260
- "batch_size": DEFAULT_BATCH_SIZE,
261
- "learning_rate": DEFAULT_LEARNING_RATE,
262
- "save_iterations": DEFAULT_NB_TRAINING_STEPS,
263
- "training_preset": list(TRAINING_PRESETS.keys())[0]
264
- }
265
- self.save_ui_state(default_state)
266
-
267
  # Modify save_session to also store the UI state at training start
268
  def save_session(self, params: Dict) -> None:
269
  """Save training session parameters"""
@@ -412,8 +541,12 @@ class TrainingService:
412
  save_iterations: int,
413
  repo_id: str,
414
  preset_name: str,
415
- training_type: str = "lora",
416
  resume_from_checkpoint: Optional[str] = None,
 
 
 
 
417
  ) -> Tuple[str, str]:
418
  """Start training with finetrainers"""
419
 
@@ -431,6 +564,10 @@ class TrainingService:
431
  log_prefix = "Resuming" if is_resuming else "Initializing"
432
  logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
433
 
 
 
 
 
434
  try:
435
  # Get absolute paths - FIXED to look in project root instead of within vms directory
436
  current_dir = Path(__file__).parent.parent.parent.absolute() # Go up to project root
@@ -459,6 +596,10 @@ class TrainingService:
459
  logger.info("Current working directory: %s", current_dir)
460
  logger.info("Training script path: %s", train_script)
461
  logger.info("Training data path: %s", TRAINING_PATH)
 
 
 
 
462
 
463
  videos_file, prompts_file = prepare_finetrainers_dataset()
464
  if videos_file is None or prompts_file is None:
@@ -474,32 +615,45 @@ class TrainingService:
474
  logger.error(error_msg)
475
  return error_msg, "No training data available"
476
 
 
 
 
 
477
  # Get preset configuration
478
  preset = TRAINING_PRESETS[preset_name]
479
  training_buckets = preset["training_buckets"]
480
  flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
481
  preset_training_type = preset.get("training_type", "lora")
482
 
 
 
 
 
 
 
 
 
 
 
483
  # Create a proper dataset configuration JSON file
484
  dataset_config_file = OUTPUT_PATH / "dataset_config.json"
485
 
486
- # Determine appropriate ID token based on model type
487
- id_token = None
488
- if model_type == "hunyuan_video":
489
- id_token = "afkx"
490
- elif model_type == "ltx_video":
491
- id_token = "BW_STYLE"
492
- # Wan doesn't use an ID token by default, so leave it as None
493
 
494
  dataset_config = {
495
  "datasets": [
496
  {
497
  "data_root": str(TRAINING_PATH),
498
- "dataset_type": "video",
499
  "id_token": id_token,
500
  "video_resolution_buckets": [[f, h, w] for f, h, w in training_buckets],
501
- "reshape_mode": "bicubic",
502
- "remove_common_llm_caption_prefixes": True
503
  }
504
  ]
505
  }
@@ -552,6 +706,16 @@ class TrainingService:
552
  logger.error(error_msg)
553
  return error_msg, "Unsupported model"
554
 
 
 
 
 
 
 
 
 
 
 
555
  # Update with UI parameters
556
  config.train_steps = int(train_steps)
557
  config.batch_size = int(batch_size)
@@ -560,7 +724,19 @@ class TrainingService:
560
  config.training_type = training_type
561
  config.flow_weighting_scheme = flow_weighting_scheme
562
 
563
- # CRITICAL FIX: Update the dataset_config to point to the JSON file, not the directory
 
564
  config.data_root = str(dataset_config_file)
565
 
566
  # Update LoRA parameters if using LoRA training type
@@ -574,7 +750,7 @@ class TrainingService:
574
  self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
575
 
576
  # Common settings for both models
577
- config.mixed_precision = "bf16"
578
  config.seed = DEFAULT_SEED
579
  config.gradient_checkpointing = True
580
  config.enable_slicing = True
@@ -598,7 +774,7 @@ class TrainingService:
598
  torchrun_args = [
599
  "torchrun",
600
  "--standalone",
601
- "--nproc_per_node=1",
602
  "--nnodes=1",
603
  "--rdzv_backend=c10d",
604
  "--rdzv_endpoint=localhost:0",
@@ -623,11 +799,29 @@ class TrainingService:
623
  launch_args = torchrun_args
624
  else:
625
  # For other models, use accelerate launch as before
 
626
  # Configure accelerate parameters
627
  accelerate_args = [
628
  "accelerate", "launch",
 
 
629
  "--mixed_precision=bf16",
630
- "--num_processes=1",
631
  "--num_machines=1",
632
  "--dynamo_backend=no",
633
  str(train_script)
@@ -647,7 +841,11 @@ class TrainingService:
647
  env["WANDB_MODE"] = "offline"
648
  env["HF_API_TOKEN"] = HF_API_TOKEN
649
  env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
650
-
 
 
 
 
651
  # Start the training process
652
  process = subprocess.Popen(
653
  launch_args + config_args,
@@ -675,6 +873,9 @@ class TrainingService:
675
  "batch_size": batch_size,
676
  "learning_rate": learning_rate,
677
  "save_iterations": save_iterations,
 
 
 
678
  "repo_id": repo_id,
679
  "start_time": datetime.now().isoformat()
680
  })
@@ -699,6 +900,10 @@ class TrainingService:
699
  self.append_log(success_msg)
700
  logger.info(success_msg)
701
 
 
 
 
 
702
  return success_msg, self.get_logs()
703
 
704
  except Exception as e:
@@ -1064,19 +1269,28 @@ class TrainingService:
1064
  if output:
1065
  # Remove decode() since output is already a string due to universal_newlines=True
1066
  line = output.strip()
 
1067
  if is_error:
1068
- #self.append_log(f"ERROR: {line}")
1069
  #logger.error(line)
1070
- #logger.info(line)
1071
- self.append_log(line)
1072
- else:
1073
- self.append_log(line)
1074
- # Parse metrics only from stdout
1075
- metrics = parse_training_log(line)
1076
- if metrics:
1077
- status = self.get_status()
1078
- status.update(metrics)
1079
- self.save_status(**status)
 
1080
  return True
1081
  return False
1082
 
 
28
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
29
  DEFAULT_LEARNING_RATE,
30
  DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
31
+ DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR,
32
+ DEFAULT_SEED, DEFAULT_RESHAPE_MODE,
33
+ DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES,
34
+ DEFAULT_DATASET_TYPE, DEFAULT_PROMPT_PREFIX,
35
+ DEFAULT_MIXED_PRECISION, DEFAULT_TRAINING_TYPE,
36
+ DEFAULT_NUM_GPUS,
37
+ DEFAULT_MAX_GPUS,
38
+ DEFAULT_PRECOMPUTATION_ITEMS,
39
+ DEFAULT_NB_TRAINING_STEPS,
40
+ DEFAULT_NB_LR_WARMUP_STEPS
41
+ )
42
+ from ..utils import (
43
+ get_available_gpu_count,
44
+ make_archive,
45
+ parse_training_log,
46
+ is_image_file,
47
+ is_video_file,
48
+ prepare_finetrainers_dataset,
49
+ copy_files_to_training_dir
50
  )
 
51
 
52
  logger = logging.getLogger(__name__)
53
 
 
124
 
125
 
126
  def save_ui_state(self, values: Dict[str, Any]) -> None:
127
+ """Save current UI state to file with validation"""
128
  ui_state_file = OUTPUT_PATH / "ui_state.json"
129
+
130
+ # Validate values before saving
131
+ validated_values = {}
132
+ default_state = {
133
+ "model_type": list(MODEL_TYPES.keys())[0],
134
+ "training_type": list(TRAINING_TYPES.keys())[0],
135
+ "lora_rank": DEFAULT_LORA_RANK_STR,
136
+ "lora_alpha": DEFAULT_LORA_ALPHA_STR,
137
+ "train_steps": DEFAULT_NB_TRAINING_STEPS,
138
+ "batch_size": DEFAULT_BATCH_SIZE,
139
+ "learning_rate": DEFAULT_LEARNING_RATE,
140
+ "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
141
+ "training_preset": list(TRAINING_PRESETS.keys())[0],
142
+ "num_gpus": DEFAULT_NUM_GPUS,
143
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
144
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
145
+ }
146
+
147
+ # Copy default values first
148
+ validated_values = default_state.copy()
149
+
150
+ # Update with provided values, converting types as needed
151
+ for key, value in values.items():
152
+ if key in default_state:
153
+ if key == "train_steps":
154
+ try:
155
+ validated_values[key] = int(value)
156
+ except (ValueError, TypeError):
157
+ validated_values[key] = default_state[key]
158
+ elif key == "batch_size":
159
+ try:
160
+ validated_values[key] = int(value)
161
+ except (ValueError, TypeError):
162
+ validated_values[key] = default_state[key]
163
+ elif key == "learning_rate":
164
+ try:
165
+ validated_values[key] = float(value)
166
+ except (ValueError, TypeError):
167
+ validated_values[key] = default_state[key]
168
+ elif key == "save_iterations":
169
+ try:
170
+ validated_values[key] = int(value)
171
+ except (ValueError, TypeError):
172
+ validated_values[key] = default_state[key]
173
+ elif key == "lora_rank" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
174
+ validated_values[key] = default_state[key]
175
+ elif key == "lora_alpha" and value not in ["16", "32", "64", "128", "256", "512", "1024"]:
176
+ validated_values[key] = default_state[key]
177
+ else:
178
+ validated_values[key] = value
179
+
180
  try:
181
+ # First verify we can serialize to JSON
182
+ json_data = json.dumps(validated_values, indent=2)
183
+
184
+ # Write to the file
185
  with open(ui_state_file, 'w') as f:
186
+ f.write(json_data)
187
+ logger.debug(f"UI state saved successfully")
188
  except Exception as e:
189
  logger.error(f"Error saving UI state: {str(e)}")
190
 
191
+ def _backup_and_recreate_ui_state(self, ui_state_file, default_state):
192
+ """Backup the corrupted UI state file and create a new one with defaults"""
193
+ try:
194
+ # Create a backup with timestamp
195
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
196
+ backup_file = ui_state_file.with_suffix(f'.json.bak_{timestamp}')
197
+
198
+ # Copy the corrupted file
199
+ shutil.copy2(ui_state_file, backup_file)
200
+ logger.info(f"Backed up corrupted UI state file to {backup_file}")
201
+ except Exception as backup_error:
202
+ logger.error(f"Failed to backup corrupted UI state file: {str(backup_error)}")
203
+
204
+ # Create a new file with default values
205
+ self.save_ui_state(default_state)
206
+ logger.info("Created new UI state file with default values after error")
207
+
208
  def load_ui_state(self) -> Dict[str, Any]:
209
+ """Load saved UI state with robust error handling"""
210
  ui_state_file = OUTPUT_PATH / "ui_state.json"
211
  default_state = {
212
  "model_type": list(MODEL_TYPES.keys())[0],
 
217
  "batch_size": DEFAULT_BATCH_SIZE,
218
  "learning_rate": DEFAULT_LEARNING_RATE,
219
  "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
220
+ "training_preset": list(TRAINING_PRESETS.keys())[0],
221
+ "num_gpus": DEFAULT_NUM_GPUS,
222
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
223
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
224
  }
225
 
226
  if not ui_state_file.exists():
 
240
  logger.warning("UI state file is empty or contains only whitespace, using default values")
241
  return default_state
242
 
243
+ try:
244
+ saved_state = json.loads(file_content)
245
+ except json.JSONDecodeError as e:
246
+ logger.error(f"Error parsing UI state JSON: {str(e)}")
247
+ # Instead of showing the error, recreate the file with defaults
248
+ self._backup_and_recreate_ui_state(ui_state_file, default_state)
249
+ return default_state
250
 
251
  # Clean up model type if it contains " (LoRA)" suffix
252
  if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
 
255
 
256
  # Convert numeric values to appropriate types
257
  if "train_steps" in saved_state:
258
+ try:
259
+ saved_state["train_steps"] = int(saved_state["train_steps"])
260
+ except (ValueError, TypeError):
261
+ saved_state["train_steps"] = default_state["train_steps"]
262
+ logger.warning("Invalid train_steps value, using default")
263
+
264
  if "batch_size" in saved_state:
265
+ try:
266
+ saved_state["batch_size"] = int(saved_state["batch_size"])
267
+ except (ValueError, TypeError):
268
+ saved_state["batch_size"] = default_state["batch_size"]
269
+ logger.warning("Invalid batch_size value, using default")
270
+
271
  if "learning_rate" in saved_state:
272
+ try:
273
+ saved_state["learning_rate"] = float(saved_state["learning_rate"])
274
+ except (ValueError, TypeError):
275
+ saved_state["learning_rate"] = default_state["learning_rate"]
276
+ logger.warning("Invalid learning_rate value, using default")
277
+
278
  if "save_iterations" in saved_state:
279
+ try:
280
+ saved_state["save_iterations"] = int(saved_state["save_iterations"])
281
+ except (ValueError, TypeError):
282
+ saved_state["save_iterations"] = default_state["save_iterations"]
283
+ logger.warning("Invalid save_iterations value, using default")
284
 
285
  # Make sure we have all keys (in case structure changed)
286
  merged_state = default_state.copy()
287
+ merged_state.update({k: v for k, v in saved_state.items() if v is not None})
288
 
289
  # Validate model_type is in available choices
290
  if merged_state["model_type"] not in MODEL_TYPES:
 
319
  merged_state["training_preset"] = default_state["training_preset"]
320
  logger.warning(f"Invalid training preset in saved state, using default")
321
 
322
+ # Validate lora_rank is in allowed values
323
+ if merged_state.get("lora_rank") not in ["16", "32", "64", "128", "256", "512", "1024"]:
324
+ merged_state["lora_rank"] = default_state["lora_rank"]
325
+ logger.warning(f"Invalid lora_rank in saved state, using default")
326
+
327
+ # Validate lora_alpha is in allowed values
328
+ if merged_state.get("lora_alpha") not in ["16", "32", "64", "128", "256", "512", "1024"]:
329
+ merged_state["lora_alpha"] = default_state["lora_alpha"]
330
+ logger.warning(f"Invalid lora_alpha in saved state, using default")
331
+
332
  return merged_state
 
 
 
333
  except Exception as e:
334
  logger.error(f"Error loading UI state: {str(e)}")
335
+ # If anything goes wrong, backup and recreate
336
+ self._backup_and_recreate_ui_state(ui_state_file, default_state)
337
  return default_state
338
 
339
  def ensure_valid_ui_state_file(self):
340
  """Ensure UI state file exists and is valid JSON"""
341
  ui_state_file = OUTPUT_PATH / "ui_state.json"
342
 
343
+ # Default state with all required values
344
+ default_state = {
345
+ "model_type": list(MODEL_TYPES.keys())[0],
346
+ "training_type": list(TRAINING_TYPES.keys())[0],
347
+ "lora_rank": DEFAULT_LORA_RANK_STR,
348
+ "lora_alpha": DEFAULT_LORA_ALPHA_STR,
349
+ "train_steps": DEFAULT_NB_TRAINING_STEPS,
350
+ "batch_size": DEFAULT_BATCH_SIZE,
351
+ "learning_rate": DEFAULT_LEARNING_RATE,
352
+ "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
353
+ "training_preset": list(TRAINING_PRESETS.keys())[0],
354
+ "num_gpus": DEFAULT_NUM_GPUS,
355
+ "precomputation_items": DEFAULT_PRECOMPUTATION_ITEMS,
356
+ "lr_warmup_steps": DEFAULT_NB_LR_WARMUP_STEPS
357
+ }
358
+
359
+ # If file doesn't exist, create it with default values
360
  if not ui_state_file.exists():
 
361
  logger.info("Creating new UI state file with default values")
 
 
 
 
 
 
 
 
 
 
 
362
  self.save_ui_state(default_state)
363
  return
364
 
365
  # Check if file is valid JSON
366
  try:
367
+ # First check if the file is empty
368
+ file_size = ui_state_file.stat().st_size
369
+ if file_size == 0:
370
+ logger.warning("UI state file exists but is empty, recreating with default values")
371
+ self.save_ui_state(default_state)
372
+ return
373
+
374
  with open(ui_state_file, 'r') as f:
375
  file_content = f.read().strip()
376
  if not file_content:
377
+ logger.warning("UI state file is empty or contains only whitespace, recreating with default values")
378
+ self.save_ui_state(default_state)
379
+ return
380
+
381
+ # Try to parse the JSON content
382
+ try:
383
+ saved_state = json.loads(file_content)
384
+ logger.debug("UI state file validation successful")
385
+ except json.JSONDecodeError as e:
386
+ # JSON parsing failed, backup and recreate
387
+ logger.error(f"Error parsing UI state JSON: {str(e)}")
388
+ self._backup_and_recreate_ui_state(ui_state_file, default_state)
389
+ return
390
  except Exception as e:
391
+ # Any other error (file access, etc)
392
+ logger.error(f"Error checking UI state file: {str(e)}")
393
+ self._backup_and_recreate_ui_state(ui_state_file, default_state)
394
+ return
395
+
 
396
  # Modify save_session to also store the UI state at training start
397
  def save_session(self, params: Dict) -> None:
398
  """Save training session parameters"""
 
541
  save_iterations: int,
542
  repo_id: str,
543
  preset_name: str,
544
+ training_type: str = DEFAULT_TRAINING_TYPE,
545
  resume_from_checkpoint: Optional[str] = None,
546
+ num_gpus: int = DEFAULT_NUM_GPUS,
547
+ precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
548
+ lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
549
+ progress: Optional[gr.Progress] = None,
550
  ) -> Tuple[str, str]:
551
  """Start training with finetrainers"""
552
 
 
564
  log_prefix = "Resuming" if is_resuming else "Initializing"
565
  logger.info(f"{log_prefix} training with model_type={model_type}, training_type={training_type}")
566
 
567
+ # Update progress if available
568
+ if progress:
569
+ progress(0.15, desc="Setting up training configuration")
570
+
571
  try:
572
  # Get absolute paths - FIXED to look in project root instead of within vms directory
573
  current_dir = Path(__file__).parent.parent.parent.absolute() # Go up to project root
 
596
  logger.info("Current working directory: %s", current_dir)
597
  logger.info("Training script path: %s", train_script)
598
  logger.info("Training data path: %s", TRAINING_PATH)
599
+
600
+ # Update progress
601
+ if progress:
602
+ progress(0.2, desc="Preparing training dataset")
603
 
604
  videos_file, prompts_file = prepare_finetrainers_dataset()
605
  if videos_file is None or prompts_file is None:
 
615
  logger.error(error_msg)
616
  return error_msg, "No training data available"
617
 
618
+ # Update progress
619
+ if progress:
620
+ progress(0.25, desc="Creating dataset configuration")
621
+
622
  # Get preset configuration
623
  preset = TRAINING_PRESETS[preset_name]
624
  training_buckets = preset["training_buckets"]
625
  flow_weighting_scheme = preset.get("flow_weighting_scheme", "none")
626
  preset_training_type = preset.get("training_type", "lora")
627
 
628
+ # Get the custom prompt prefix from the tabs
629
+ custom_prompt_prefix = None
630
+ if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
631
+ if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
632
+ # Get the value and clean it
633
+ prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
634
+ if prefix:
635
+ # Clean the prefix - remove trailing comma, space or comma+space
636
+ custom_prompt_prefix = prefix.rstrip(', ')
637
+
638
  # Create a proper dataset configuration JSON file
639
  dataset_config_file = OUTPUT_PATH / "dataset_config.json"
640
 
641
+ # Determine appropriate ID token based on model type and custom prefix
642
+ id_token = custom_prompt_prefix # Use custom prefix as the primary id_token
643
+
644
+ # Only use default ID tokens if no custom prefix is provided
645
+ if not id_token:
646
+ id_token = DEFAULT_PROMPT_PREFIX
 
647
 
648
  dataset_config = {
649
  "datasets": [
650
  {
651
  "data_root": str(TRAINING_PATH),
652
+ "dataset_type": DEFAULT_DATASET_TYPE,
653
  "id_token": id_token,
654
  "video_resolution_buckets": [[f, h, w] for f, h, w in training_buckets],
655
+ "reshape_mode": DEFAULT_RESHAPE_MODE,
656
+ "remove_common_llm_caption_prefixes": DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES,
657
  }
658
  ]
659
  }
 
706
  logger.error(error_msg)
707
  return error_msg, "Unsupported model"
708
 
709
+ # Create validation dataset if needed
710
+ validation_file = None
711
+ #if enable_validation: # Add a parameter to control this
712
+ # validation_file = create_validation_config()
713
+ # if validation_file:
714
+ # config_args.extend([
715
+ # "--validation_dataset_file", str(validation_file),
716
+ # "--validation_steps", "500" # Set this to a suitable value
717
+ # ])
718
+
719
  # Update with UI parameters
720
  config.train_steps = int(train_steps)
721
  config.batch_size = int(batch_size)
 
724
  config.training_type = training_type
725
  config.flow_weighting_scheme = flow_weighting_scheme
726
 
727
+ config.lr_warmup_steps = int(lr_warmup_steps)
728
+ config_args.extend([
729
+ "--precomputation_items", str(precomputation_items)
730
+ ])
731
+
732
+ # Update the NUM_GPUS variable and CUDA_VISIBLE_DEVICES
733
+ num_gpus = min(num_gpus, get_available_gpu_count())
734
+ if num_gpus <= 0:
735
+ num_gpus = 1
736
+
737
+ # Generate CUDA_VISIBLE_DEVICES string
738
+ visible_devices = ",".join([str(i) for i in range(num_gpus)])
739
+
740
  config.data_root = str(dataset_config_file)
741
 
742
  # Update LoRA parameters if using LoRA training type
 
750
  self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
751
 
752
  # Common settings for both models
753
+ config.mixed_precision = DEFAULT_MIXED_PRECISION
754
  config.seed = DEFAULT_SEED
755
  config.gradient_checkpointing = True
756
  config.enable_slicing = True
 
774
  torchrun_args = [
775
  "torchrun",
776
  "--standalone",
777
+ "--nproc_per_node=" + str(num_gpus),
778
  "--nnodes=1",
779
  "--rdzv_backend=c10d",
780
  "--rdzv_endpoint=localhost:0",
 
799
  launch_args = torchrun_args
800
  else:
801
  # For other models, use accelerate launch as before
802
+ # Determine the appropriate accelerate config file based on num_gpus
803
+ accelerate_config = None
804
+ if num_gpus == 1:
805
+ accelerate_config = "accelerate_configs/uncompiled_1.yaml"
806
+ elif num_gpus == 2:
807
+ accelerate_config = "accelerate_configs/uncompiled_2.yaml"
808
+ elif num_gpus == 4:
809
+ accelerate_config = "accelerate_configs/uncompiled_4.yaml"
810
+ elif num_gpus == 8:
811
+ accelerate_config = "accelerate_configs/uncompiled_8.yaml"
812
+ else:
813
+ # Default to 1 GPU config if no matching config is found
814
+ accelerate_config = "accelerate_configs/uncompiled_1.yaml"
815
+ num_gpus = 1
816
+ visible_devices = "0"
817
+
818
  # Configure accelerate parameters
819
  accelerate_args = [
820
  "accelerate", "launch",
821
+ "--config_file", accelerate_config,
822
+ "--gpu_ids", visible_devices,
823
  "--mixed_precision=bf16",
824
+ "--num_processes=" + str(num_gpus),
825
  "--num_machines=1",
826
  "--dynamo_backend=no",
827
  str(train_script)
 
841
  env["WANDB_MODE"] = "offline"
842
  env["HF_API_TOKEN"] = HF_API_TOKEN
843
  env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" # Added for better debugging
844
+ env["CUDA_VISIBLE_DEVICES"] = visible_devices
845
+
846
+ if progress:
847
+ progress(0.9, desc="Launching training process")
848
+
849
  # Start the training process
850
  process = subprocess.Popen(
851
  launch_args + config_args,
 
873
  "batch_size": batch_size,
874
  "learning_rate": learning_rate,
875
  "save_iterations": save_iterations,
876
+ "num_gpus": num_gpus,
877
+ "precomputation_items": precomputation_items,
878
+ "lr_warmup_steps": lr_warmup_steps,
879
  "repo_id": repo_id,
880
  "start_time": datetime.now().isoformat()
881
  })
 
900
  self.append_log(success_msg)
901
  logger.info(success_msg)
902
 
903
+ # Final progress update - now we'll track it through the log monitor
904
+ if progress:
905
+ progress(1.0, desc="Training started successfully")
906
+
907
  return success_msg, self.get_logs()
908
 
909
  except Exception as e:
 
1269
  if output:
1270
  # Remove decode() since output is already a string due to universal_newlines=True
1271
  line = output.strip()
1272
+ self.append_log(line)
1273
  if is_error:
 
1274
  #logger.error(line)
1275
+ pass
1276
+
1277
+ # Parse metrics only from stdout
1278
+ metrics = parse_training_log(line)
1279
+ if metrics:
1280
+ status = self.get_status()
1281
+ status.update(metrics)
1282
+ self.save_status(**status)
1283
+
1284
+ # Extract total_steps and current_step for progress tracking
1285
+ if 'step' in metrics:
1286
+ current_step = metrics['step']
1287
+ if 'total_steps' in status:
1288
+ total_steps = status['total_steps']
1289
+
1290
+ # Update progress bar if available and total_steps is known
1291
+ if progress_obj and total_steps > 0:
1292
+ progress_value = min(0.99, current_step / total_steps)
1293
+ progress_obj(progress_value, desc=f"Training: step {current_step}/{total_steps}")
1294
  return True
1295
  return False
1296
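The `get_available_gpu_count` helper imported at the top of this file is not shown in the diff; a plausible sketch of what such a helper usually looks like (an assumption about its behaviour, not the repository's actual implementation) is:

```python
import torch

def get_available_gpu_count() -> int:
    """Assumed behaviour: number of CUDA devices visible to this process, 0 if none."""
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 0

# In start_training() the requested GPU count is clamped with this value:
# num_gpus = min(num_gpus, get_available_gpu_count()), then forced to at least 1.
```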
 
vms/tabs/train_tab.py CHANGED
@@ -15,7 +15,13 @@ from ..config import (
15
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
16
  DEFAULT_LEARNING_RATE,
17
  DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
18
- DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 
19
  )
20
 
21
  logger = logging.getLogger(__name__)
@@ -106,7 +112,30 @@ class TrainTab(BaseTab):
106
  precision=0,
107
  info="Model will be saved periodically after these many steps"
108
  )
109
-
 
 
110
  with gr.Column():
111
  with gr.Row():
112
  # Check for existing checkpoints to determine button text
@@ -218,7 +247,27 @@ class TrainTab(BaseTab):
218
  self.components["lora_params_row"]
219
  ]
220
  )
221
-
 
 
 
222
  # Training parameters change events
223
  self.components["lora_rank"].change(
224
  fn=lambda v: self.app.update_ui_state(lora_rank=v),
@@ -274,7 +323,10 @@ class TrainTab(BaseTab):
274
  self.components["learning_rate"],
275
  self.components["save_iterations"],
276
  self.components["preset_info"],
277
- self.components["lora_params_row"]
 
 
 
278
  ]
279
  )
280
 
@@ -332,7 +384,7 @@ class TrainTab(BaseTab):
332
  outputs=[self.components["status_box"]]
333
  )
334
 
335
- def handle_training_start(self, preset, model_type, training_type, *args):
336
  """Handle training start with proper log parser reset and checkpoint detection"""
337
  # Safely reset log parser if it exists
338
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
@@ -341,6 +393,9 @@ class TrainTab(BaseTab):
341
  logger.warning("Log parser not initialized, creating a new one")
342
  from ..utils import TrainingLogParser
343
  self.app.log_parser = TrainingLogParser()
 
 
 
344
 
345
  # Check for latest checkpoint
346
  checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
@@ -351,6 +406,9 @@ class TrainTab(BaseTab):
351
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
352
  resume_from = str(latest_checkpoint)
353
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
 
 
 
354
 
355
  # Convert model_type display name to internal name
356
  model_internal_type = MODEL_TYPES.get(model_type)
@@ -366,19 +424,32 @@ class TrainTab(BaseTab):
366
  logger.error(f"Invalid training type: {training_type}")
367
  return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
368
 
 
 
 
369
  # Start training (it will automatically use the checkpoint if provided)
370
  try:
371
  return self.app.trainer.start_training(
372
- model_internal_type, # Use internal model type
373
- *args,
 
 
 
 
 
 
374
  preset_name=preset,
375
- training_type=training_internal_type, # Pass the internal training type
376
- resume_from_checkpoint=resume_from
 
 
 
 
377
  )
378
  except Exception as e:
379
  logger.exception("Error starting training")
380
  return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
381
-
382
  def get_model_info(self, model_type: str, training_type: str) -> str:
383
  """Get information about the selected model type and training method"""
384
  if model_type == "HunyuanVideo":
@@ -518,6 +589,9 @@ class TrainTab(BaseTab):
518
  batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
519
  learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
520
  save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
 
 
 
521
 
522
  # Return values in the same order as the output components
523
  return (
@@ -530,7 +604,10 @@ class TrainTab(BaseTab):
530
  learning_rate_val,
531
  save_iterations_val,
532
  info_text,
533
- gr.Row(visible=show_lora_params)
 
 
 
534
  )
535
 
536
  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
 
15
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
16
  DEFAULT_LEARNING_RATE,
17
  DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
18
+ DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR,
19
+ DEFAULT_SEED,
20
+ DEFAULT_NUM_GPUS,
21
+ DEFAULT_MAX_GPUS,
22
+ DEFAULT_PRECOMPUTATION_ITEMS,
23
+ DEFAULT_NB_TRAINING_STEPS,
24
+ DEFAULT_NB_LR_WARMUP_STEPS,
25
  )
26
 
27
  logger = logging.getLogger(__name__)
 
112
  precision=0,
113
  info="Model will be saved periodically after these many steps"
114
  )
115
+ with gr.Row():
116
+ self.components["num_gpus"] = gr.Slider(
117
+ label="Number of GPUs to use",
118
+ value=DEFAULT_NUM_GPUS,
119
+ minimum=1,
120
+ maximum=DEFAULT_MAX_GPUS,
121
+ step=1,
122
+ info="Number of GPUs to use for training"
123
+ )
124
+ self.components["precomputation_items"] = gr.Number(
125
+ label="Precomputation Items",
126
+ value=DEFAULT_PRECOMPUTATION_ITEMS,
127
+ minimum=1,
128
+ precision=0,
129
+ info="Should be more or less the number of total items (ex: 200 videos), divided by the number of GPUs"
130
+ )
131
+ with gr.Row():
132
+ self.components["lr_warmup_steps"] = gr.Number(
133
+ label="Learning Rate Warmup Steps",
134
+ value=DEFAULT_NB_LR_WARMUP_STEPS,
135
+ minimum=0,
136
+ precision=0,
137
+ info="Number of warmup steps (typically 20-40% of total training steps)"
138
+ )
139
  with gr.Column():
140
  with gr.Row():
141
  # Check for existing checkpoints to determine button text
 
247
  self.components["lora_params_row"]
248
  ]
249
  )
250
+
251
+
252
+ # Add in the connect_events() method:
253
+ self.components["num_gpus"].change(
254
+ fn=lambda v: self.app.update_ui_state(num_gpus=v),
255
+ inputs=[self.components["num_gpus"]],
256
+ outputs=[]
257
+ )
258
+
259
+ self.components["precomputation_items"].change(
260
+ fn=lambda v: self.app.update_ui_state(precomputation_items=v),
261
+ inputs=[self.components["precomputation_items"]],
262
+ outputs=[]
263
+ )
264
+
265
+ self.components["lr_warmup_steps"].change(
266
+ fn=lambda v: self.app.update_ui_state(lr_warmup_steps=v),
267
+ inputs=[self.components["lr_warmup_steps"]],
268
+ outputs=[]
269
+ )
270
+
271
  # Training parameters change events
272
  self.components["lora_rank"].change(
273
  fn=lambda v: self.app.update_ui_state(lora_rank=v),
 
323
  self.components["learning_rate"],
324
  self.components["save_iterations"],
325
  self.components["preset_info"],
326
+ self.components["lora_params_row"],
327
+ self.components["num_gpus"],
328
+ self.components["precomputation_items"],
329
+ self.components["lr_warmup_steps"]
330
  ]
331
  )
332
 
 
384
  outputs=[self.components["status_box"]]
385
  )
386
 
387
+ def handle_training_start(self, preset, model_type, training_type, *args, progress=gr.Progress()):
388
  """Handle training start with proper log parser reset and checkpoint detection"""
389
  # Safely reset log parser if it exists
390
  if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
 
393
  logger.warning("Log parser not initialized, creating a new one")
394
  from ..utils import TrainingLogParser
395
  self.app.log_parser = TrainingLogParser()
396
+
397
+ # Initialize progress
398
+ progress(0, desc="Initializing training")
399
 
400
  # Check for latest checkpoint
401
  checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
 
406
  latest_checkpoint = max(checkpoints, key=os.path.getmtime)
407
  resume_from = str(latest_checkpoint)
408
  logger.info(f"Found checkpoint at {resume_from}, will resume training")
409
+ progress(0.05, desc=f"Resuming from checkpoint {Path(resume_from).name}")
410
+ else:
411
+ progress(0.05, desc="Starting new training run")
412
 
413
  # Convert model_type display name to internal name
414
  model_internal_type = MODEL_TYPES.get(model_type)
 
424
  logger.error(f"Invalid training type: {training_type}")
425
  return f"Error: Invalid training type '{training_type}'", "Training type not recognized"
426
 
427
+ # Progress update
428
+ progress(0.1, desc="Preparing dataset")
429
+
430
  # Start training (it will automatically use the checkpoint if provided)
431
  try:
432
  return self.app.trainer.start_training(
433
+ model_internal_type,
434
+ lora_rank,
435
+ lora_alpha,
436
+ train_steps,
437
+ batch_size,
438
+ learning_rate,
439
+ save_iterations,
440
+ repo_id,
441
  preset_name=preset,
442
+ training_type=training_internal_type,
443
+ resume_from_checkpoint=resume_from,
444
+ num_gpus=num_gpus,
445
+ precomputation_items=precomputation_items,
446
+ lr_warmup_steps=lr_warmup_steps,
447
+ progress=progress
448
  )
449
  except Exception as e:
450
  logger.exception("Error starting training")
451
  return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
452
+
453
  def get_model_info(self, model_type: str, training_type: str) -> str:
454
  """Get information about the selected model type and training method"""
455
  if model_type == "HunyuanVideo":
 
589
  batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
590
  learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
591
  save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
592
+ num_gpus_val = current_state.get("num_gpus") if current_state.get("num_gpus") != preset.get("num_gpus", DEFAULT_NUM_GPUS) else preset.get("num_gpus", DEFAULT_NUM_GPUS)
593
+ precomputation_items_val = current_state.get("precomputation_items") if current_state.get("precomputation_items") != preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS) else preset.get("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
594
+ lr_warmup_steps_val = current_state.get("lr_warmup_steps") if current_state.get("lr_warmup_steps") != preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS) else preset.get("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)
595
 
596
  # Return values in the same order as the output components
597
  return (
 
604
  learning_rate_val,
605
  save_iterations_val,
606
  info_text,
607
+ gr.Row(visible=show_lora_params),
608
+ num_gpus_val,
609
+ precomputation_items_val,
610
+ lr_warmup_steps_val
611
  )
612
 
613
  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
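
The new `handle_training_start` signature takes a `gr.Progress()` default, so Gradio injects a progress tracker that the handler can call at each stage (`progress(0.05, desc=...)`, and so on). As a self-contained reminder of how that Gradio mechanism behaves, here is a minimal sketch with an illustrative handler that is not part of this repository:

```python
import time
import gradio as gr

def run_job(steps: int, progress=gr.Progress()):
    # Gradio injects a Progress tracker because the parameter defaults to gr.Progress()
    progress(0, desc="Initializing")
    for step in range(int(steps)):
        time.sleep(0.1)  # stand-in for real work
        progress((step + 1) / steps, desc=f"Step {step + 1}/{int(steps)}")
    return "done"

with gr.Blocks() as demo:
    steps = gr.Number(value=10, precision=0, label="Steps")
    result = gr.Textbox(label="Result")
    gr.Button("Run").click(fn=run_job, inputs=[steps], outputs=[result])

if __name__ == "__main__":
    demo.launch()
```
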
vms/ui/video_trainer_ui.py CHANGED
@@ -14,9 +14,20 @@ from ..config import (
  DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
  DEFAULT_LEARNING_RATE,
  DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
- DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
+ DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR,
+ DEFAULT_SEED,
+ DEFAULT_NUM_GPUS,
+ DEFAULT_MAX_GPUS,
+ DEFAULT_PRECOMPUTATION_ITEMS,
+ DEFAULT_NB_TRAINING_STEPS,
+ DEFAULT_NB_LR_WARMUP_STEPS
+ )
+ from ..utils import (
+ get_recommended_precomputation_items,
+ count_media_files,
+ format_media_title,
+ TrainingLogParser
  )
- from ..utils import count_media_files, format_media_title, TrainingLogParser
  from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab

  logger = logging.getLogger(__name__)
@@ -101,7 +112,10 @@ class VideoTrainerUI:
  self.tabs["train_tab"].components["batch_size"],
  self.tabs["train_tab"].components["learning_rate"],
  self.tabs["train_tab"].components["save_iterations"],
- self.tabs["train_tab"].components["current_task_box"] # Add new component
+ self.tabs["train_tab"].components["current_task_box"],
+ self.tabs["train_tab"].components["num_gpus"],
+ self.tabs["train_tab"].components["precomputation_items"],
+ self.tabs["train_tab"].components["lr_warmup_steps"]
  ]
  )

@@ -273,11 +287,26 @@
  # Rest of the function remains unchanged
  lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
  lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
- train_steps_val = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
  batch_size_val = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
  learning_rate_val = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
  save_iterations_val = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))

+ # Update for new UI components
+ num_gpus_val = int(ui_state.get("num_gpus", DEFAULT_NUM_GPUS))
+
+ # Calculate recommended precomputation items based on video count
+ video_count = len(list(TRAINING_VIDEOS_PATH.glob('*.mp4')))
+ recommended_precomputation = get_recommended_precomputation_items(video_count, num_gpus_val)
+ precomputation_items_val = int(ui_state.get("precomputation_items", recommended_precomputation))
+
+ # Ensure warmup steps are not more than training steps
+ train_steps_val = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+ default_warmup = min(DEFAULT_NB_LR_WARMUP_STEPS, int(train_steps_val * 0.2))
+ lr_warmup_steps_val = int(ui_state.get("lr_warmup_steps", default_warmup))
+
+ # Ensure warmup steps <= training steps
+ lr_warmup_steps_val = min(lr_warmup_steps_val, train_steps_val)
+
  # Initial current task value
  current_task_val = ""
  if hasattr(self, 'log_parser') and self.log_parser:
@@ -299,7 +328,10 @@
  batch_size_val,
  learning_rate_val,
  save_iterations_val,
- current_task_val # Add current task value
+ current_task_val,
+ num_gpus_val,
+ precomputation_items_val,
+ lr_warmup_steps_val
  )

  def initialize_ui_from_state(self):
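
The defaults computed above follow two simple rules: precomputation items default to roughly the video count divided by the GPU count, and warmup steps default to at most 20% of the training steps, clamped so they never exceed the total steps. A small worked example with made-up numbers (the `DEFAULT_NB_LR_WARMUP_STEPS` value here is an assumption for illustration only):

```python
num_videos = 200        # hypothetical dataset size
num_gpus = 2            # hypothetical GPU count
train_steps = 1000      # hypothetical total training steps
DEFAULT_NB_LR_WARMUP_STEPS = 200  # assumed constant, not the repository's actual value

# Items per GPU, at least 1 (get_recommended_precomputation_items additionally caps this at 512)
precomputation_items = max(1, num_videos // num_gpus)        # -> 100

# Warmup: at most 20% of the training steps, never more than the steps themselves
default_warmup = min(DEFAULT_NB_LR_WARMUP_STEPS, int(train_steps * 0.2))  # -> 200
lr_warmup_steps = min(default_warmup, train_steps)                        # -> 200

print(precomputation_items, lr_warmup_steps)
```
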
vms/utils/__init__.py CHANGED
@@ -8,6 +8,8 @@ from .finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_trai

  from . import webdataset_handler

+ from .gpu_detector import get_available_gpu_count, get_gpu_info, get_recommended_precomputation_items
+
  __all__ = [
  'validate_model_repo',
  'make_archive',
@@ -33,5 +35,9 @@ __all__ = [
  'prepare_finetrainers_dataset',
  'copy_files_to_training_dir',

- 'webdataset_handler'
+ 'webdataset_handler',
+
+ 'get_available_gpu_count',
+ 'get_gpu_info',
+ 'get_recommended_precomputation_items'
  ]
vms/utils/finetrainers_utils.py CHANGED
@@ -4,15 +4,22 @@ import logging
  import shutil
  from typing import Any, Optional, Dict, List, Union, Tuple

- from ..config import STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
+ from ..config import (
+ STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES,
+ DEFAULT_VALIDATION_NB_STEPS,
+ DEFAULT_VALIDATION_HEIGHT,
+ DEFAULT_VALIDATION_WIDTH,
+ DEFAULT_VALIDATION_NB_FRAMES,
+ DEFAULT_VALIDATION_FRAMERATE
+ )
  from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file

  logger = logging.getLogger(__name__)

  def prepare_finetrainers_dataset() -> Tuple[Path, Path]:
- """make sure we have a Finetrainers-compatible dataset structure
+ """Prepare a Finetrainers-compatible dataset structure

- Checks that we have:
+ Creates:
  training/
  ├── prompt.txt # All captions, one per line
  ├── videos.txt # All video paths, one per line
@@ -30,14 +37,15 @@ def prepare_finetrainers_dataset() -> Tuple[Path, Path]:
  # Clear existing training lists
  for f in TRAINING_PATH.glob("*"):
  if f.is_file():
- if f.name in ["videos.txt", "prompts.txt"]:
+ if f.name in ["videos.txt", "prompts.txt", "prompt.txt"]:
  f.unlink()

  videos_file = TRAINING_PATH / "videos.txt"
- prompts_file = TRAINING_PATH / "prompts.txt" # Note: Changed from prompt.txt to prompts.txt to match our config
+ prompts_file = TRAINING_PATH / "prompts.txt" # Finetrainers can use either prompts.txt or prompt.txt

  media_files = []
  captions = []
+
  # Process all video files from the videos subdirectory
  for idx, file in enumerate(sorted(TRAINING_VIDEOS_PATH.glob("*.mp4"))):
  caption_file = file.with_suffix('.txt')
@@ -50,19 +58,16 @@ def prepare_finetrainers_dataset() -> Tuple[Path, Path]:
  relative_path = f"videos/{file.name}"
  media_files.append(relative_path)
  captions.append(caption)
-
- # Clean up the caption file since it's now in prompts.txt
- # EDIT well you know what, let's keep it, otherwise running the function
- # twice might cause some errors
- # caption_file.unlink()

  # Write files if we have content
  if media_files and captions:
  videos_file.write_text('\n'.join(media_files))
  prompts_file.write_text('\n'.join(captions))
-
+ logger.info(f"Created dataset with {len(media_files)} video/caption pairs")
  else:
- raise ValueError("No valid video/caption pairs found in training directory")
+ logger.warning("No valid video/caption pairs found in training directory")
+ return None, None
+
  # Verify file contents
  with open(videos_file) as vf:
  video_lines = [l.strip() for l in vf.readlines() if l.strip()]
@@ -70,7 +75,8 @@ def prepare_finetrainers_dataset() -> Tuple[Path, Path]:
  prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]

  if len(video_lines) != len(prompt_lines):
- raise ValueError(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
+ logger.error(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
+ return None, None

  return videos_file, prompts_file

@@ -137,3 +143,67 @@ def copy_files_to_training_dir(prompt_prefix: str) -> int:
  gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")

  return nb_copied_pairs
+
+ # Add this function to finetrainers_utils.py or a suitable place
+
+ def create_validation_config() -> Optional[Path]:
+ """Create a validation configuration JSON file for Finetrainers
+
+ Creates a validation dataset file with a subset of the training data
+
+ Returns:
+ Path to the validation JSON file, or None if no training files exist
+ """
+ # Ensure training dataset exists
+ if not TRAINING_VIDEOS_PATH.exists() or not any(TRAINING_VIDEOS_PATH.glob("*.mp4")):
+ logger.warning("No training videos found for validation")
+ return None
+
+ # Get a subset of the training videos (up to 4) for validation
+ training_videos = list(TRAINING_VIDEOS_PATH.glob("*.mp4"))
+ validation_videos = training_videos[:min(4, len(training_videos))]
+
+ if not validation_videos:
+ logger.warning("No validation videos selected")
+ return None
+
+ # Create validation data entries
+ validation_data = {"data": []}
+
+ for video_path in validation_videos:
+ # Get caption from matching text file
+ caption_path = video_path.with_suffix('.txt')
+ if not caption_path.exists():
+ logger.warning(f"Missing caption for {video_path}, skipping for validation")
+ continue
+
+ caption = caption_path.read_text().strip()
+
+ # Get video dimensions and properties
+ try:
+ # Use the most common default resolution and settings
+ data_entry = {
+ "caption": caption,
+ "image_path": "", # No input image for text-to-video
+ "video_path": str(video_path),
+ "num_inference_steps": DEFAULT_VALIDATION_NB_STEPS,
+ "height": DEFAULT_VALIDATION_HEIGHT,
+ "width": DEFAULT_VALIDATION_WIDTH,
+ "num_frames": DEFAULT_VALIDATION_NB_FRAMES,
+ "frame_rate": DEFAULT_VALIDATION_FRAMERATE
+ }
+ validation_data["data"].append(data_entry)
+ except Exception as e:
+ logger.warning(f"Error adding validation entry for {video_path}: {e}")
+
+ if not validation_data["data"]:
+ logger.warning("No valid validation entries created")
+ return None
+
+ # Write validation config to file
+ validation_file = OUTPUT_PATH / "validation_config.json"
+ with open(validation_file, 'w') as f:
+ json.dump(validation_data, f, indent=2)
+
+ logger.info(f"Created validation config with {len(validation_data['data'])} entries")
+ return validation_file
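
The new `create_validation_config` helper samples up to four training clips and serializes them into `validation_config.json` under `OUTPUT_PATH`. The sketch below only illustrates the shape of that file; the caption, paths, and numeric values are placeholders rather than the repository's actual `DEFAULT_VALIDATION_*` constants:

```python
import json
from pathlib import Path

# Placeholder values for illustration; the real defaults come from the project config
validation_config = {
    "data": [
        {
            "caption": "a person walking through a park at sunset",
            "image_path": "",                               # empty for text-to-video validation
            "video_path": "training/videos/clip_0001.mp4",  # hypothetical clip path
            "num_inference_steps": 30,
            "height": 512,
            "width": 768,
            "num_frames": 49,
            "frame_rate": 16,
        }
    ]
}

Path("validation_config.json").write_text(json.dumps(validation_config, indent=2))
```
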
vms/utils/gpu_detector.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ def get_available_gpu_count():
+ """Get the number of available GPUs on the system.
+
+ Returns:
+ int: Number of available GPUs, or 0 if no GPUs are available
+ """
+ try:
+ if torch.cuda.is_available():
+ return torch.cuda.device_count()
+ else:
+ return 0
+ except Exception as e:
+ logger.warning(f"Error detecting GPUs: {e}")
+ return 0
+
+ def get_gpu_info():
+ """Get information about available GPUs.
+
+ Returns:
+ list: List of dictionaries with GPU information
+ """
+ gpu_info = []
+ try:
+ if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ gpu = {
+ 'index': i,
+ 'name': torch.cuda.get_device_name(i),
+ 'memory_total': torch.cuda.get_device_properties(i).total_memory
+ }
+ gpu_info.append(gpu)
+ except Exception as e:
+ logger.warning(f"Error getting GPU details: {e}")
+
+ return gpu_info
+
+ def get_recommended_precomputation_items(num_videos, num_gpus):
+ """Calculate recommended precomputation items.
+
+ Args:
+ num_videos (int): Number of videos in dataset
+ num_gpus (int): Number of GPUs to use
+
+ Returns:
+ int: Recommended precomputation items value
+ """
+ if num_gpus <= 0:
+ num_gpus = 1
+
+ # Calculate items per GPU, but ensure it's at least 1
+ items_per_gpu = max(1, num_videos // num_gpus)
+
+ # Limit to a maximum of 512
+ return min(512, items_per_gpu)
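
Because `gpu_detector` is re-exported from `vms.utils` (see the `__init__.py` change above), the new helpers can be exercised directly. A usage sketch, assuming the repository root is on the Python path and PyTorch is installed:

```python
from vms.utils import (
    get_available_gpu_count,
    get_gpu_info,
    get_recommended_precomputation_items,
)

num_gpus = get_available_gpu_count()
print(f"Detected {num_gpus} GPU(s)")

for gpu in get_gpu_info():
    # memory_total comes from torch.cuda.get_device_properties(i).total_memory, in bytes
    print(f"  [{gpu['index']}] {gpu['name']}: {gpu['memory_total'] / 1024**3:.1f} GiB")

# e.g. 200 videos spread across the detected GPUs (the helper falls back to 1 GPU if none are found)
print(get_recommended_precomputation_items(num_videos=200, num_gpus=num_gpus))
```
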