jbilcke-hf (HF Staff) committed
Commit c6546ad · 1 Parent(s): 38cfbff

cleaning code

vms/config.py CHANGED
@@ -58,9 +58,9 @@ JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
 
 # Expanded model types to include Wan-2.1-T2V
 MODEL_TYPES = {
-    "HunyuanVideo (LoRA)": "hunyuan_video",
-    "LTX-Video (LoRA)": "ltx_video",
-    "Wan-2.1-T2V (LoRA)": "wan"
+    "HunyuanVideo": "hunyuan_video",
+    "LTX-Video": "ltx_video",
+    "Wan-2.1-T2V": "wan"
 }
 
 # Training types
@@ -69,6 +69,23 @@ TRAINING_TYPES = {
     "Full Finetune": "full-finetune"
 }
 
+DEFAULT_SEED = 42
+
+DEFAULT_NB_TRAINING_STEPS = 1000
+
+DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200
+
+DEFAULT_LORA_RANK = 128
+DEFAULT_LORA_RANK_STR = str(DEFAULT_LORA_RANK)
+
+DEFAULT_LORA_ALPHA = 128
+DEFAULT_LORA_ALPHA_STR = str(DEFAULT_LORA_ALPHA)
+
+DEFAULT_CAPTION_DROPOUT_P = 0.05
+
+DEFAULT_BATCH_SIZE = 1
+
+DEFAULT_LEARNING_RATE = 3e-5
 
 # it is best to use resolutions that are powers of 8
 # The resolution should be divisible by 32
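Editor's note on the constants above: Gradio dropdowns traffic in strings while the training config wants integers, so the rank/alpha defaults exist in both forms, derived from one numeric source of truth. A minimal sketch of the round-trip (the conversion helper is illustrative, not part of the commit):

```python
DEFAULT_LORA_RANK = 128
DEFAULT_LORA_RANK_STR = str(DEFAULT_LORA_RANK)  # what the UI dropdown stores

def lora_rank_from_ui(value: str) -> int:
    """Hypothetical helper: convert the dropdown string back to an int."""
    return int(value)

assert lora_rank_from_ui(DEFAULT_LORA_RANK_STR) == DEFAULT_LORA_RANK
```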
@@ -87,39 +104,49 @@ MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
 NB_FRAMES_1 = 1 # 1
 NB_FRAMES_9 = 8 + 1 # 8 + 1
 NB_FRAMES_17 = 8 * 2 + 1 # 16 + 1
-NB_FRAMES_32 = 8 * 4 + 1 # 32 + 1
-NB_FRAMES_48 = 8 * 6 + 1 # 48 + 1
-NB_FRAMES_64 = 8 * 8 + 1 # 64 + 1
-NB_FRAMES_80 = 8 * 10 + 1 # 80 + 1
-NB_FRAMES_96 = 8 * 12 + 1 # 96 + 1
-NB_FRAMES_112 = 8 * 14 + 1 # 112 + 1
-NB_FRAMES_128 = 8 * 16 + 1 # 128 + 1
-NB_FRAMES_144 = 8 * 18 + 1 # 144 + 1
-NB_FRAMES_160 = 8 * 20 + 1 # 160 + 1
-NB_FRAMES_176 = 8 * 22 + 1 # 176 + 1
-NB_FRAMES_192 = 8 * 24 + 1 # 192 + 1
-NB_FRAMES_224 = 8 * 28 + 1 # 224 + 1
-NB_FRAMES_256 = 8 * 32 + 1 # 256 + 1
+NB_FRAMES_33 = 8 * 4 + 1 # 32 + 1
+NB_FRAMES_49 = 8 * 6 + 1 # 48 + 1
+NB_FRAMES_65 = 8 * 8 + 1 # 64 + 1
+NB_FRAMES_81 = 8 * 10 + 1 # 80 + 1
+NB_FRAMES_97 = 8 * 12 + 1 # 96 + 1
+NB_FRAMES_113 = 8 * 14 + 1 # 112 + 1
+NB_FRAMES_129 = 8 * 16 + 1 # 128 + 1
+NB_FRAMES_145 = 8 * 18 + 1 # 144 + 1
+NB_FRAMES_161 = 8 * 20 + 1 # 160 + 1
+NB_FRAMES_177 = 8 * 22 + 1 # 176 + 1
+NB_FRAMES_193 = 8 * 24 + 1 # 192 + 1
+NB_FRAMES_225 = 8 * 28 + 1 # 224 + 1
+NB_FRAMES_257 = 8 * 32 + 1 # 256 + 1
 # 256 isn't a lot by the way, especially with 60 FPS videos..
 # can we crank it and put more frames in here?
 
+NB_FRAMES_273 = 8 * 34 + 1 # 272 + 1
+NB_FRAMES_289 = 8 * 36 + 1 # 288 + 1
+NB_FRAMES_305 = 8 * 38 + 1 # 304 + 1
+NB_FRAMES_321 = 8 * 40 + 1 # 320 + 1
+NB_FRAMES_337 = 8 * 42 + 1 # 336 + 1
+NB_FRAMES_353 = 8 * 44 + 1 # 352 + 1
+NB_FRAMES_369 = 8 * 46 + 1 # 368 + 1
+NB_FRAMES_385 = 8 * 48 + 1 # 384 + 1
+NB_FRAMES_401 = 8 * 50 + 1 # 400 + 1
+
 SMALL_TRAINING_BUCKETS = [
     (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
     (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
     (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
-    (NB_FRAMES_32, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
-    (NB_FRAMES_48, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
-    (NB_FRAMES_64, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
-    (NB_FRAMES_80, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
-    (NB_FRAMES_96, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
-    (NB_FRAMES_112, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
-    (NB_FRAMES_128, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
-    (NB_FRAMES_144, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
-    (NB_FRAMES_160, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
-    (NB_FRAMES_176, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
-    (NB_FRAMES_192, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
-    (NB_FRAMES_224, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
-    (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
+    (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
 MEDIUM_19_9_RATIO_WIDTH = 928 # 32 * 29
@@ -129,19 +156,19 @@ MEDIUM_19_9_RATIO_BUCKETS = [
     (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
     (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
     (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
-    (NB_FRAMES_32, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
-    (NB_FRAMES_48, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
-    (NB_FRAMES_64, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
-    (NB_FRAMES_80, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
-    (NB_FRAMES_96, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
-    (NB_FRAMES_112, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
-    (NB_FRAMES_128, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
-    (NB_FRAMES_144, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
-    (NB_FRAMES_160, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
-    (NB_FRAMES_176, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
-    (NB_FRAMES_192, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
-    (NB_FRAMES_224, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
-    (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
+    (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
 # Updated training presets to include Wan-2.1-T2V and support both LoRA and full-finetune
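Editor's note: the renames above (NB_FRAMES_32 to NB_FRAMES_33, and so on) make each name match its actual value. The comments in config.py imply a causal video VAE with 8x temporal compression, so valid clip lengths have the form 8k + 1: one leading frame plus k groups of 8. A small sketch of that arithmetic, under that factor-8 assumption:

```python
def valid_frame_count(k: int) -> int:
    """Pixel frames decoded from 1 + k latent frames under 8x temporal
    compression: 1, 9, 17, 33, ..."""
    return 8 * k + 1

def snap_down_to_bucket(n_frames: int) -> int:
    """Largest valid clip length <= n_frames, e.g. for trimming a video."""
    return 8 * ((n_frames - 1) // 8) + 1

assert [valid_frame_count(k) for k in (0, 1, 2, 4)] == [1, 9, 17, 33]
assert snap_down_to_bucket(50) == 49  # a 50-frame clip trims to 49
```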
@@ -149,24 +176,24 @@ TRAINING_PRESETS = {
     "HunyuanVideo (normal)": {
         "model_type": "hunyuan_video",
         "training_type": "lora",
-        "lora_rank": "128",
-        "lora_alpha": "128",
-        "num_epochs": 70,
-        "batch_size": 1,
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
         "learning_rate": 2e-5,
-        "save_iterations": 500,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "none"
     },
     "LTX-Video (normal)": {
         "model_type": "ltx_video",
         "training_type": "lora",
-        "lora_rank": "128",
-        "lora_alpha": "128",
-        "num_epochs": 70,
-        "batch_size": 1,
-        "learning_rate": 3e-5,
-        "save_iterations": 500,
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
@@ -174,21 +201,21 @@ TRAINING_PRESETS = {
         "model_type": "ltx_video",
         "training_type": "lora",
         "lora_rank": "256",
-        "lora_alpha": "128",
-        "num_epochs": 50,
-        "batch_size": 1,
-        "learning_rate": 3e-5,
-        "save_iterations": 200,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
     "LTX-Video (Full Finetune)": {
         "model_type": "ltx_video",
         "training_type": "full-finetune",
-        "num_epochs": 30,
-        "batch_size": 1,
-        "learning_rate": 1e-5,
-        "save_iterations": 300,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
@@ -197,10 +224,10 @@ TRAINING_PRESETS = {
         "training_type": "lora",
         "lora_rank": "32",
         "lora_alpha": "32",
-        "num_epochs": 70,
-        "batch_size": 1,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
         "learning_rate": 5e-5,
-        "save_iterations": 500,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
@@ -209,10 +236,10 @@ TRAINING_PRESETS = {
         "training_type": "lora",
         "lora_rank": "64",
         "lora_alpha": "64",
-        "num_epochs": 50,
-        "batch_size": 1,
-        "learning_rate": 3e-5,
-        "save_iterations": 200,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     }
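For orientation, each preset above is a flat dict of UI defaults plus a bucket list; nothing interprets it until the UI and TrainingService.start_training read individual keys. A hedged sketch of a lookup (resolve_preset is illustrative and not part of the commit):

```python
from typing import Any, Dict, Tuple

def resolve_preset(presets: Dict[str, Dict[str, Any]], name: str) -> Tuple[Dict[str, Any], list]:
    """Split a preset into UI-facing fields and the resolution buckets
    that go straight into the training config."""
    preset = dict(presets[name])            # copy so callers can't mutate it
    buckets = preset.pop("training_buckets")
    return preset, buckets

# e.g. ui_fields, buckets = resolve_preset(TRAINING_PRESETS, "LTX-Video (normal)")
```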
@@ -244,7 +271,7 @@ class TrainingConfig:
     id_token: Optional[str] = None
     video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SMALL_TRAINING_BUCKETS)
     video_reshape_mode: str = "center"
-    caption_dropout_p: float = 0.05
+    caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
     caption_dropout_technique: str = "empty"
     precompute_conditions: bool = False
 
@@ -257,16 +284,16 @@ class TrainingConfig:
 
     # Training arguments
     training_type: str = "lora"
-    seed: int = 42
+    seed: int = DEFAULT_SEED
     mixed_precision: str = "bf16"
     batch_size: int = 1
-    train_epochs: int = 70
-    lora_rank: int = 128
-    lora_alpha: int = 128
+    train_steps: int = DEFAULT_NB_TRAINING_STEPS
+    lora_rank: int = DEFAULT_LORA_RANK
+    lora_alpha: int = DEFAULT_LORA_ALPHA
     target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
     gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
-    checkpointing_steps: int = 500
+    checkpointing_steps: int = DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS
     checkpointing_limit: Optional[int] = 2
     resume_from_checkpoint: Optional[str] = None
     enable_slicing: bool = True
@@ -300,15 +327,15 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            train_epochs=70,
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=2e-5,
             gradient_checkpointing=True,
             id_token="afkx",
             gradient_accumulation_steps=1,
-            lora_rank=128,
-            lora_alpha=128,
+            lora_rank=DEFAULT_LORA_RANK,
+            lora_alpha=DEFAULT_LORA_ALPHA,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=0.05,
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="none", # Hunyuan specific
             training_type="lora"
         )
@@ -322,15 +349,15 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            train_epochs=40,
-            lr=3e-5,
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
+            lr=DEFAULT_LEARNING_RATE,
             gradient_checkpointing=True,
             id_token="BW_STYLE",
             gradient_accumulation_steps=4,
-            lora_rank=128,
-            lora_alpha=128,
+            lora_rank=DEFAULT_LORA_RANK,
+            lora_alpha=DEFAULT_LORA_ALPHA,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=0.05,
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # LTX specific
             training_type="lora"
         )
@@ -344,13 +371,13 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            train_epochs=30,
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=1e-5,
             gradient_checkpointing=True,
             id_token="BW_STYLE",
             gradient_accumulation_steps=1,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=0.05,
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # LTX specific
             training_type="full-finetune"
         )
@@ -364,7 +391,7 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            train_epochs=70,
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=5e-5,
             gradient_checkpointing=True,
             id_token=None, # Default is no ID token for Wan
@@ -373,7 +400,7 @@ class TrainingConfig:
             lora_alpha=32,
             target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=0.05,
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # Wan specific
             training_type="lora"
         )
@@ -428,7 +455,7 @@ class TrainingConfig:
         #args.extend(["--mixed_precision", self.mixed_precision])
 
         args.extend(["--batch_size", str(self.batch_size)])
-        args.extend(["--train_steps", str(self.train_epochs * 1000)]) # Convert epochs to steps for compatibility
+        args.extend(["--train_steps", str(self.train_steps)])
 
         # LoRA specific arguments
         if self.training_type == "lora":
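The last hunk is the behavioral core of the commit: the CLI used to receive a synthetic `train_epochs * 1000` step count while the UI labeled the field "epochs"; now the user-specified step count is passed through verbatim. For anyone who still thinks in epochs, the honest conversion is steps-per-epoch times epochs (a sketch, not code from the repo):

```python
import math

def epochs_to_steps(num_epochs: int, num_videos: int, batch_size: int) -> int:
    """One epoch = one pass over the dataset."""
    steps_per_epoch = math.ceil(max(1, num_videos) / batch_size)
    return num_epochs * steps_per_epoch

# 70 epochs over 100 videos at batch size 1 is 7,000 steps -- not the
# 70 * 1000 = 70,000 the old `--train_steps` conversion would have requested.
assert epochs_to_steps(70, 100, 1) == 7000
```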
 
vms/services/trainer.py CHANGED
@@ -23,7 +23,12 @@ from huggingface_hub import upload_folder, create_repo
 from ..config import (
     TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
     STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
-    MODEL_TYPES, TRAINING_TYPES
+    MODEL_TYPES, TRAINING_TYPES,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE, DEFAULT_SEED,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 )
 from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
 
@@ -111,18 +116,19 @@ class TrainingService:
         except Exception as e:
             logger.error(f"Error saving UI state: {str(e)}")
 
+    # Additional fix for the load_ui_state method in trainer.py to clean up old values
     def load_ui_state(self) -> Dict[str, Any]:
         """Load saved UI state"""
         ui_state_file = OUTPUT_PATH / "ui_state.json"
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
             "training_type": list(TRAINING_TYPES.keys())[0],
-            "lora_rank": "128",
-            "lora_alpha": "128",
-            "num_epochs": 50,
-            "batch_size": 1,
-            "learning_rate": 3e-5,
-            "save_iterations": 200,
+            "lora_rank": DEFAULT_LORA_RANK_STR,
+            "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+            "train_steps": DEFAULT_NB_TRAINING_STEPS,
+            "batch_size": DEFAULT_BATCH_SIZE,
+            "learning_rate": DEFAULT_LEARNING_RATE,
+            "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
             "training_preset": list(TRAINING_PRESETS.keys())[0]
         }
 
@@ -145,9 +151,14 @@ class TrainingService:
 
             saved_state = json.loads(file_content)
 
+            # Clean up model type if it contains " (LoRA)" suffix
+            if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
+                saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
+                logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
+
             # Convert numeric values to appropriate types
-            if "num_epochs" in saved_state:
-                saved_state["num_epochs"] = int(saved_state["num_epochs"])
+            if "train_steps" in saved_state:
+                saved_state["train_steps"] = int(saved_state["train_steps"])
             if "batch_size" in saved_state:
                 saved_state["batch_size"] = int(saved_state["batch_size"])
             if "learning_rate" in saved_state:
@@ -158,6 +169,40 @@ class TrainingService:
             # Make sure we have all keys (in case structure changed)
             merged_state = default_state.copy()
             merged_state.update(saved_state)
+
+            # Validate model_type is in available choices
+            if merged_state["model_type"] not in MODEL_TYPES:
+                # Try to map from internal name
+                model_found = False
+                for display_name, internal_name in MODEL_TYPES.items():
+                    if internal_name == merged_state["model_type"]:
+                        merged_state["model_type"] = display_name
+                        model_found = True
+                        break
+                # If still not found, use default
+                if not model_found:
+                    merged_state["model_type"] = default_state["model_type"]
+                    logger.warning(f"Invalid model type in saved state, using default")
+
+            # Validate training_type is in available choices
+            if merged_state["training_type"] not in TRAINING_TYPES:
+                # Try to map from internal name
+                training_found = False
+                for display_name, internal_name in TRAINING_TYPES.items():
+                    if internal_name == merged_state["training_type"]:
+                        merged_state["training_type"] = display_name
+                        training_found = True
+                        break
+                # If still not found, use default
+                if not training_found:
+                    merged_state["training_type"] = default_state["training_type"]
+                    logger.warning(f"Invalid training type in saved state, using default")
+
+            # Validate training_preset is in available choices
+            if merged_state["training_preset"] not in TRAINING_PRESETS:
+                merged_state["training_preset"] = default_state["training_preset"]
+                logger.warning(f"Invalid training preset in saved state, using default")
+
             return merged_state
         except json.JSONDecodeError as e:
             logger.error(f"Error parsing UI state JSON: {str(e)}")
@@ -176,12 +221,12 @@ class TrainingService:
             default_state = {
                 "model_type": list(MODEL_TYPES.keys())[0],
                 "training_type": list(TRAINING_TYPES.keys())[0],
-                "lora_rank": "128",
-                "lora_alpha": "128",
-                "num_epochs": 50,
-                "batch_size": 1,
-                "learning_rate": 3e-5,
-                "save_iterations": 200,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
                 "training_preset": list(TRAINING_PRESETS.keys())[0]
             }
             self.save_ui_state(default_state)
@@ -209,12 +254,12 @@ class TrainingService:
             default_state = {
                 "model_type": list(MODEL_TYPES.keys())[0],
                 "training_type": list(TRAINING_TYPES.keys())[0],
-                "lora_rank": "128",
-                "lora_alpha": "128",
-                "num_epochs": 50,
-                "batch_size": 1,
-                "learning_rate": 3e-5,
-                "save_iterations": 200,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
                 "training_preset": list(TRAINING_PRESETS.keys())[0]
             }
             self.save_ui_state(default_state)
@@ -361,7 +406,7 @@ class TrainingService:
         model_type: str,
         lora_rank: str,
         lora_alpha: str,
-        num_epochs: int,
+        train_steps: int,
         batch_size: int,
         learning_rate: float,
         save_iterations: int,
@@ -508,7 +553,7 @@ class TrainingService:
             return error_msg, "Unsupported model"
 
         # Update with UI parameters
-        config.train_epochs = int(num_epochs)
+        config.train_steps = int(train_steps)
         config.batch_size = int(batch_size)
         config.lr = float(learning_rate)
         config.checkpointing_steps = int(save_iterations)
@@ -530,11 +575,11 @@ class TrainingService:
 
         # Common settings for both models
         config.mixed_precision = "bf16"
-        config.seed = 42
+        config.seed = DEFAULT_SEED
         config.gradient_checkpointing = True
         config.enable_slicing = True
         config.enable_tiling = True
-        config.caption_dropout_p = 0.05
+        config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
 
         validation_error = self.validate_training_config(config, model_type)
         if validation_error:
@@ -626,7 +671,7 @@ class TrainingService:
             "training_type": training_type,
             "lora_rank": lora_rank,
             "lora_alpha": lora_alpha,
-            "num_epochs": num_epochs,
+            "train_steps": train_steps,
             "batch_size": batch_size,
             "learning_rate": learning_rate,
             "save_iterations": save_iterations,
@@ -635,14 +680,12 @@ class TrainingService:
         })
 
         # Update initial training status
-        total_steps = num_epochs * (max(1, video_count) // batch_size)
+        total_steps = int(train_steps)
         self.save_status(
             state='training',
-            epoch=0,
             step=0,
             total_steps=total_steps,
             loss=0.0,
-            total_epochs=num_epochs,
             message='Training started',
             repo_id=repo_id,
             model_type=model_type,
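With epochs gone, the reported total is simply the requested step count, and the checkpoint cadence follows from the defaults introduced in config.py: 1000 steps with a save every 200 steps writes five checkpoints, of which `checkpointing_limit = 2` keeps only the most recent two. The arithmetic:

```python
DEFAULT_NB_TRAINING_STEPS = 1000
DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200
CHECKPOINTING_LIMIT = 2  # from TrainingConfig.checkpointing_limit

written = DEFAULT_NB_TRAINING_STEPS // DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS
kept = min(written, CHECKPOINTING_LIMIT)
assert (written, kept) == (5, 2)
```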
@@ -789,12 +832,12 @@ class TrainingService:
             "params": {
                 "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
                 "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])),
-                "lora_rank": ui_state.get("lora_rank", "128"),
-                "lora_alpha": ui_state.get("lora_alpha", "128"),
-                "num_epochs": ui_state.get("num_epochs", 70),
-                "batch_size": ui_state.get("batch_size", 1),
-                "learning_rate": ui_state.get("learning_rate", 3e-5),
-                "save_iterations": ui_state.get("save_iterations", 500),
+                "lora_rank": ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
+                "lora_alpha": ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
+                "train_steps": ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
+                "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+                "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
                 "repo_id": "" # Default empty repo ID
             }
@@ -853,12 +896,12 @@ class TrainingService:
             ui_updates.update({
                 "model_type": model_type_display, # Use the display name for the UI dropdown
                 "training_type": training_type_display, # Use the display name for training type
-                "lora_rank": params.get('lora_rank', "128"),
-                "lora_alpha": params.get('lora_alpha', "128"),
-                "num_epochs": params.get('num_epochs', 70),
-                "batch_size": params.get('batch_size', 1),
-                "learning_rate": params.get('learning_rate', 3e-5),
-                "save_iterations": params.get('save_iterations', 500),
+                "lora_rank": params.get('lora_rank', DEFAULT_LORA_RANK_STR),
+                "lora_alpha": params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
+                "train_steps": params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
+                "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
+                "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
             })
 
@@ -872,12 +915,12 @@ class TrainingService:
             # But keep model_type_display for the UI
             result = self.start_training(
                 model_type=model_type_internal,
-                lora_rank=params.get('lora_rank', "128"),
-                lora_alpha=params.get('lora_alpha', "128"),
-                num_epochs=params.get('num_epochs', 70),
-                batch_size=params.get('batch_size', 1),
-                learning_rate=params.get('learning_rate', 3e-5),
-                save_iterations=params.get('save_iterations', 500),
+                lora_rank=params.get('lora_rank', DEFAULT_LORA_RANK_STR),
+                lora_alpha=params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
+                train_steps=params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
+                batch_size=params.get('batch_size', DEFAULT_BATCH_SIZE),
+                learning_rate=params.get('learning_rate', DEFAULT_LEARNING_RATE),
+                save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 repo_id=params.get('repo_id', ''),
                 preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
                 training_type=training_type_internal,
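Everything above reads and writes the same persisted state; after this commit, a fresh ui_state.json under OUTPUT_PATH would look roughly like the following (field values illustrative, shape inferred from default_state):

```python
import json
from pathlib import Path

ui_state = {
    "model_type": "HunyuanVideo",     # display name, not "hunyuan_video"
    "training_type": "LoRA Finetune",
    "lora_rank": "128",
    "lora_alpha": "128",
    "train_steps": 1000,              # replaces the old "num_epochs"
    "batch_size": 1,
    "learning_rate": 3e-5,
    "save_iterations": 200,
    "training_preset": "HunyuanVideo (normal)",
}
path = Path("ui_state.json")
path.write_text(json.dumps(ui_state, indent=2))
assert json.loads(path.read_text())["train_steps"] == 1000
```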
 
vms/tabs/train_tab.py CHANGED
@@ -9,7 +9,14 @@ from typing import Dict, Any, List, Optional, Tuple
 from pathlib import Path
 
 from .base_tab import BaseTab
-from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES
+from ..config import (
+    TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
+)
 
 logger = logging.getLogger(__name__)
 
@@ -63,20 +70,20 @@ class TrainTab(BaseTab):
         self.components["lora_rank"] = gr.Dropdown(
             label="LoRA Rank",
             choices=["16", "32", "64", "128", "256", "512", "1024"],
-            value="128",
+            value=DEFAULT_LORA_RANK_STR,
             type="value"
         )
         self.components["lora_alpha"] = gr.Dropdown(
             label="LoRA Alpha",
             choices=["16", "32", "64", "128", "256", "512", "1024"],
-            value="128",
+            value=DEFAULT_LORA_ALPHA_STR,
             type="value"
         )
 
         with gr.Row():
-            self.components["num_epochs"] = gr.Number(
-                label="Number of Epochs",
-                value=70,
+            self.components["train_steps"] = gr.Number(
+                label="Number of Training Steps",
+                value=DEFAULT_NB_TRAINING_STEPS,
                 minimum=1,
                 precision=0
             )
@@ -89,13 +96,13 @@ class TrainTab(BaseTab):
         with gr.Row():
             self.components["learning_rate"] = gr.Number(
                 label="Learning Rate",
-                value=2e-5,
-                minimum=1e-7
+                value=DEFAULT_LEARNING_RATE,
+                minimum=1e-8
             )
             self.components["save_iterations"] = gr.Number(
                 label="Save checkpoint every N iterations",
-                value=500,
-                minimum=50,
+                value=DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                minimum=1,
                 precision=0,
                 info="Model will be saved periodically after these many steps"
             )
@@ -170,7 +177,7 @@ class TrainTab(BaseTab):
 
         return {
             self.components["model_info"]: info,
-            self.components["num_epochs"]: params["num_epochs"],
+            self.components["train_steps"]: params["train_steps"],
             self.components["batch_size"]: params["batch_size"],
             self.components["learning_rate"]: params["learning_rate"],
             self.components["save_iterations"]: params["save_iterations"],
@@ -186,7 +193,7 @@ class TrainTab(BaseTab):
             inputs=[self.components["model_type"], self.components["training_type"]],
             outputs=[
                 self.components["model_info"],
-                self.components["num_epochs"],
+                self.components["train_steps"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
@@ -204,7 +211,7 @@ class TrainTab(BaseTab):
             inputs=[self.components["model_type"], self.components["training_type"]],
             outputs=[
                 self.components["model_info"],
-                self.components["num_epochs"],
+                self.components["train_steps"],
                 self.components["batch_size"],
                 self.components["learning_rate"],
                 self.components["save_iterations"],
@@ -225,9 +232,9 @@ class TrainTab(BaseTab):
             outputs=[]
         )
 
-        self.components["num_epochs"].change(
-            fn=lambda v: self.app.update_ui_state(num_epochs=v),
-            inputs=[self.components["num_epochs"]],
+        self.components["train_steps"].change(
+            fn=lambda v: self.app.update_ui_state(train_steps=v),
+            inputs=[self.components["train_steps"]],
             outputs=[]
         )
 
@@ -262,7 +269,7 @@ class TrainTab(BaseTab):
             self.components["training_type"],
             self.components["lora_rank"],
             self.components["lora_alpha"],
-            self.components["num_epochs"],
+            self.components["train_steps"],
             self.components["batch_size"],
             self.components["learning_rate"],
             self.components["save_iterations"],
@@ -280,7 +287,7 @@ class TrainTab(BaseTab):
             self.components["training_type"],
             self.components["lora_rank"],
             self.components["lora_alpha"],
-            self.components["num_epochs"],
+            self.components["train_steps"],
             self.components["batch_size"],
             self.components["learning_rate"],
             self.components["save_iterations"],
@@ -290,27 +297,20 @@ class TrainTab(BaseTab):
                 self.components["status_box"],
                 self.components["log_box"]
             ]
-        ).success(
-            fn=self.get_latest_status_message_logs_and_button_labels,
-            outputs=[
-                self.components["status_box"],
-                self.components["log_box"],
-                self.components["start_btn"],
-                self.components["stop_btn"],
-                self.components["pause_resume_btn"],
-                self.components["current_task_box"] # Include new component
-            ]
         )
 
+        # Use simplified event handlers for pause/resume and stop
+        third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
+
         self.components["pause_resume_btn"].click(
             fn=self.handle_pause_resume,
             outputs=[
                 self.components["status_box"],
                 self.components["log_box"],
+                self.components["current_task_box"],
                 self.components["start_btn"],
                 self.components["stop_btn"],
-                self.components["pause_resume_btn"],
-                self.components["current_task_box"] # Include new component
+                third_btn
             ]
         )
 
@@ -319,10 +319,10 @@ class TrainTab(BaseTab):
             outputs=[
                 self.components["status_box"],
                self.components["log_box"],
+                self.components["current_task_box"],
                 self.components["start_btn"],
                 self.components["stop_btn"],
-                self.components["pause_resume_btn"],
-                self.components["current_task_box"] # Include new component
+                third_btn
             ]
         )
 
@@ -330,16 +330,6 @@ class TrainTab(BaseTab):
         self.components["delete_checkpoints_btn"].click(
             fn=lambda: self.app.trainer.delete_all_checkpoints(),
             outputs=[self.components["status_box"]]
-        ).then(
-            fn=self.get_latest_status_message_logs_and_button_labels,
-            outputs=[
-                self.components["status_box"],
-                self.components["log_box"],
-                self.components["start_btn"],
-                self.components["stop_btn"],
-                self.components["delete_checkpoints_btn"],
-                self.components["current_task_box"] # Include new component
-            ]
         )
 
     def handle_training_start(self, preset, model_type, training_type, *args):
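The `third_btn` indirection above exists because Gradio matches a handler's return values to its `outputs` list positionally: `handle_pause_resume` must return exactly one value per output component, whichever button happens to occupy the last slot. A standalone illustration of that contract (not the app's actual handler):

```python
def handle_pause_resume():
    status_box = "paused"
    log_box = "...recent log lines..."
    current_task_box = "waiting for task information"
    start_btn, stop_btn, third_btn = "Start Training", "Stop Training", "Delete All Checkpoints"
    # Order must mirror outputs=[status_box, log_box, current_task_box,
    # start_btn, stop_btn, third_btn] in the .click() wiring above.
    return status_box, log_box, current_task_box, start_btn, stop_btn, third_btn

assert len(handle_pause_resume()) == 6  # one value per output component
```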
@@ -391,7 +381,7 @@ class TrainTab(BaseTab):
 
     def get_model_info(self, model_type: str, training_type: str) -> str:
         """Get information about the selected model type and training method"""
-        if model_type == "HunyuanVideo (LoRA)":
+        if model_type == "HunyuanVideo":
             base_info = """### HunyuanVideo
     - Required VRAM: ~48GB minimum
     - Recommended batch size: 1-2
@@ -403,7 +393,7 @@ class TrainTab(BaseTab):
             else:
                 return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
 
-        elif model_type == "LTX-Video (LoRA)":
+        elif model_type == "LTX-Video":
             base_info = """### LTX-Video
     - Recommended batch size: 1-4
     - Typical training time: 1-3 hours
@@ -414,14 +404,14 @@ class TrainTab(BaseTab):
             else:
                 return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
 
-        elif model_type == "Wan-2.1-T2V (LoRA)":
+        elif model_type == "Wan-2.1-T2V":
             base_info = """### Wan-2.1-T2V
-    - Recommended batch size: 1-2
-    - Typical training time: 1-3 hours
+    - Recommended batch size: ?
+    - Typical training time: ? hours
     - Default resolution: 49x512x768"""
 
             if training_type == "LoRA Finetune":
-                return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
+                return base_info + "\n- Required VRAM: ?GB minimum\n- Default LoRA rank: 32 (~120 MB)"
             else:
                 return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
 
@@ -440,51 +430,51 @@ class TrainTab(BaseTab):
             # Use the first matching preset
             preset = matching_presets[0]
             return {
-                "num_epochs": preset.get("num_epochs", 70),
-                "batch_size": preset.get("batch_size", 1),
-                "learning_rate": preset.get("learning_rate", 3e-5),
-                "save_iterations": preset.get("save_iterations", 500),
-                "lora_rank": preset.get("lora_rank", "128"),
-                "lora_alpha": preset.get("lora_alpha", "128")
+                "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
+                "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
+                "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+                "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
+                "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
             }
 
         # Default fallbacks
         if model_type == "hunyuan_video":
             return {
-                "num_epochs": 70,
-                "batch_size": 1,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
                 "learning_rate": 2e-5,
-                "save_iterations": 500,
-                "lora_rank": "128",
-                "lora_alpha": "128"
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR
             }
         elif model_type == "ltx_video":
             return {
-                "num_epochs": 70,
-                "batch_size": 1,
-                "learning_rate": 3e-5,
-                "save_iterations": 500,
-                "lora_rank": "128",
-                "lora_alpha": "128"
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR
             }
         elif model_type == "wan":
             return {
-                "num_epochs": 70,
-                "batch_size": 1,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
                 "learning_rate": 5e-5,
-                "save_iterations": 500,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
                 "lora_rank": "32",
                 "lora_alpha": "32"
             }
         else:
             # Generic defaults
             return {
-                "num_epochs": 70,
-                "batch_size": 1,
-                "learning_rate": 3e-5,
-                "save_iterations": 500,
-                "lora_rank": "128",
-                "lora_alpha": "128"
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR
             }
@@ -522,12 +512,12 @@ class TrainTab(BaseTab):
         show_lora_params = preset["training_type"] == "lora"
 
         # Use preset defaults but preserve user-modified values if they exist
-        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset.get("lora_rank", "128")
-        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset.get("lora_alpha", "128")
-        num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset.get("num_epochs", 70)
-        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset.get("batch_size", 1)
-        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset.get("learning_rate", 3e-5)
-        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset.get("save_iterations", 500)
+        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
+        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR) else preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+        train_steps_val = current_state.get("train_steps") if current_state.get("train_steps") != preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS) else preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS)
+        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
+        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
+        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
 
         # Return values in the same order as the output components
         return (
@@ -535,7 +525,7 @@ class TrainTab(BaseTab):
             training_display_name,
             lora_rank_val,
             lora_alpha_val,
-            num_epochs_val,
+            train_steps_val,
             batch_size_val,
             learning_rate_val,
             save_iterations_val,
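One observation on the "preserve user-modified values" conditionals above: `x if x != default else default` returns x in every case where x is set (when the two sides are equal, returning the default returns the same value), so each line only differs from a plain lookup when the key is missing and `current_state.get(...)` yields None. If a preset fallback is the intent, a simpler equivalent would be (a sketch, not the commit's code):

```python
def merged_value(current_state: dict, preset: dict, key: str, default):
    """Prefer the user's current value; fall back to the preset, then to
    the global default -- and never propagate a bare None."""
    value = current_state.get(key)
    return value if value is not None else preset.get(key, default)
```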
@@ -543,66 +533,6 @@ class TrainTab(BaseTab):
             gr.Row(visible=show_lora_params)
         )
 
-    def update_training_ui(self, training_state: Dict[str, Any]):
-        """Update UI components based on training state"""
-        updates = {}
-
-        # Update status box with high-level information
-        status_text = []
-        if training_state["status"] != "idle":
-            status_text.extend([
-                f"Status: {training_state['status']}",
-                f"Progress: {training_state['progress']}",
-                f"Step: {training_state['current_step']}/{training_state['total_steps']}",
-                f"Time elapsed: {training_state['elapsed']}",
-                f"Estimated remaining: {training_state['remaining']}",
-                "",
-                f"Current loss: {training_state['step_loss']}",
-                f"Learning rate: {training_state['learning_rate']}",
-                f"Gradient norm: {training_state['grad_norm']}",
-                f"Memory usage: {training_state['memory']}"
-            ])
-
-        if training_state["error_message"]:
-            status_text.append(f"\nError: {training_state['error_message']}")
-
-        updates["status_box"] = "\n".join(status_text)
-
-        # Add current task information to the dedicated box
-        if training_state.get("current_task"):
-            updates["current_task_box"] = training_state["current_task"]
-        else:
-            updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."
-
-        # Update button states
-        updates["start_btn"] = gr.Button(
-            "Start training",
-            interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
-            variant="primary" if training_state["status"] == "idle" else "secondary"
-        )
-
-        updates["stop_btn"] = gr.Button(
-            "Stop training",
-            interactive=(training_state["status"] in ["training", "initializing"]),
-            variant="stop"
-        )
-
-        return updates
-
-    def handle_pause_resume(self):
-        status, _, _ = self.get_latest_status_message_and_logs()
-
-        if status == "paused":
-            self.app.trainer.resume_training()
-        else:
-            self.app.trainer.pause_training()
-
-        return self.get_latest_status_message_logs_and_button_labels()
-
-    def handle_stop(self):
-        self.app.trainer.stop_training()
-        return self.get_latest_status_message_logs_and_button_labels()
-
     def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
         """Get latest status message, log content, and status code in a safer way"""
         state = self.app.trainer.get_status()
@@ -663,61 +593,107 @@ class TrainTab(BaseTab):
 
         return (state["status"], state["message"], logs)
 
-    def get_latest_status_message_logs_and_button_labels(self) -> Tuple:
-        """Get latest status message, logs and button states"""
+    def get_status_updates(self):
+        """Get status updates for text components (no variant property)"""
         status, message, logs = self.get_latest_status_message_and_logs()
 
-        # Add checkpoints detection
-        has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
-
-        button_updates = self.update_training_buttons(status, has_checkpoints).values()
-
         # Get current task if available
         current_task = ""
         if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
             current_task = self.app.log_parser.get_current_task_display()
 
-        # Return in order expected by timer (added current_task)
-        return (message, logs, *button_updates, current_task)
-
-    def update_training_buttons(self, status: str, has_checkpoints: bool = None) -> Dict:
-        """Update training control buttons based on state"""
-        if has_checkpoints is None:
-            has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
-
+        return message, logs, current_task
+
+    def get_button_updates(self):
+        """Get button updates (with variant property)"""
+        status, _, _ = self.get_latest_status_message_and_logs()
+
+        # Add checkpoints detection
+        has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
+
         is_training = status in ["training", "initializing"]
         is_completed = status in ["completed", "error", "stopped"]
 
         start_text = "Continue Training" if has_checkpoints else "Start Training"
 
-        # Only include buttons that we know exist in components
-        result = {
-            "start_btn": gr.Button(
-                value=start_text,
-                interactive=not is_training,
-                variant="primary" if not is_training else "secondary",
-            ),
-            "stop_btn": gr.Button(
-                value="Stop at Last Checkpoint",
-                interactive=is_training,
-                variant="primary" if is_training else "secondary",
-            )
-        }
-
-        # Add delete_checkpoints_btn only if it exists in components
         if "delete_checkpoints_btn" in self.components:
-            result["delete_checkpoints_btn"] = gr.Button(
-                value="Delete All Checkpoints",
                 interactive=has_checkpoints and not is_training,
-                variant="stop",
             )
         else:
-            # Add pause_resume_btn as fallback
-            result["pause_resume_btn"] = gr.Button(
-                value="Resume Training" if status == "paused" else "Pause Training",
                 interactive=(is_training or status == "paused") and not is_completed,
                 variant="secondary",
                 visible=False
             )
 
-        return result
9
  from pathlib import Path
10
 
11
  from .base_tab import BaseTab
12
+ from ..config import (
13
+ TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
14
+ DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
15
+ DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
16
+ DEFAULT_LEARNING_RATE,
17
+ DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
18
+ DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
19
+ )
20
 
21
  logger = logging.getLogger(__name__)
22
 
 
70
  self.components["lora_rank"] = gr.Dropdown(
71
  label="LoRA Rank",
72
  choices=["16", "32", "64", "128", "256", "512", "1024"],
73
+ value=DEFAULT_LORA_RANK_STR,
74
  type="value"
75
  )
76
  self.components["lora_alpha"] = gr.Dropdown(
77
  label="LoRA Alpha",
78
  choices=["16", "32", "64", "128", "256", "512", "1024"],
79
+ value=DEFAULT_LORA_ALPHA_STR,
80
  type="value"
81
  )
82
 
83
  with gr.Row():
84
+ self.components["train_steps"] = gr.Number(
85
+ label="Number of Training Steps",
86
+ value=DEFAULT_NB_TRAINING_STEPS,
87
  minimum=1,
88
  precision=0
89
  )
 
96
  with gr.Row():
97
  self.components["learning_rate"] = gr.Number(
98
  label="Learning Rate",
99
+ value=DEFAULT_LEARNING_RATE,
100
+ minimum=1e-8
101
  )
102
  self.components["save_iterations"] = gr.Number(
103
  label="Save checkpoint every N iterations",
104
+ value=DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
105
+ minimum=1,
106
  precision=0,
107
  info="Model will be saved periodically after these many steps"
108
  )
 
177
 
178
  return {
179
  self.components["model_info"]: info,
180
+ self.components["train_steps"]: params["train_steps"],
181
  self.components["batch_size"]: params["batch_size"],
182
  self.components["learning_rate"]: params["learning_rate"],
183
  self.components["save_iterations"]: params["save_iterations"],
 
193
  inputs=[self.components["model_type"], self.components["training_type"]],
194
  outputs=[
195
  self.components["model_info"],
196
+ self.components["train_steps"],
197
  self.components["batch_size"],
198
  self.components["learning_rate"],
199
  self.components["save_iterations"],
 
211
  inputs=[self.components["model_type"], self.components["training_type"]],
212
  outputs=[
213
  self.components["model_info"],
214
+ self.components["train_steps"],
215
  self.components["batch_size"],
216
  self.components["learning_rate"],
217
  self.components["save_iterations"],
 
232
  outputs=[]
233
  )
234
 
235
+ self.components["train_steps"].change(
236
+ fn=lambda v: self.app.update_ui_state(train_steps=v),
237
+ inputs=[self.components["train_steps"]],
238
  outputs=[]
239
  )
240
 
 
269
  self.components["training_type"],
270
  self.components["lora_rank"],
271
  self.components["lora_alpha"],
272
+ self.components["train_steps"],
273
  self.components["batch_size"],
274
  self.components["learning_rate"],
275
  self.components["save_iterations"],
 
287
  self.components["training_type"],
288
  self.components["lora_rank"],
289
  self.components["lora_alpha"],
290
+ self.components["train_steps"],
291
  self.components["batch_size"],
292
  self.components["learning_rate"],
293
  self.components["save_iterations"],
 
297
  self.components["status_box"],
298
  self.components["log_box"]
299
  ]
 
 
 
 
 
 
 
 
 
 
300
  )
301
 
302
+ # Use simplified event handlers for pause/resume and stop
303
+ third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
304
+
305
  self.components["pause_resume_btn"].click(
306
  fn=self.handle_pause_resume,
307
  outputs=[
308
  self.components["status_box"],
309
  self.components["log_box"],
310
+ self.components["current_task_box"],
311
  self.components["start_btn"],
312
  self.components["stop_btn"],
313
+ third_btn
 
314
  ]
315
  )
316
 
 
319
  outputs=[
320
  self.components["status_box"],
321
  self.components["log_box"],
322
+ self.components["current_task_box"],
323
  self.components["start_btn"],
324
  self.components["stop_btn"],
325
+ third_btn
 
326
  ]
327
  )
328
 
 
330
  self.components["delete_checkpoints_btn"].click(
331
  fn=lambda: self.app.trainer.delete_all_checkpoints(),
332
  outputs=[self.components["status_box"]]
 
 
 
 
 
 
 
 
 
 
333
  )
334
 
335
  def handle_training_start(self, preset, model_type, training_type, *args):
 
381
 
382
  def get_model_info(self, model_type: str, training_type: str) -> str:
383
  """Get information about the selected model type and training method"""
384
+ if model_type == "HunyuanVideo":
385
  base_info = """### HunyuanVideo
386
  - Required VRAM: ~48GB minimum
387
  - Recommended batch size: 1-2
 
393
  else:
394
  return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
395
 
396
+ elif model_type == "LTX-Video":
397
  base_info = """### LTX-Video
398
  - Recommended batch size: 1-4
399
  - Typical training time: 1-3 hours
 
404
  else:
405
  return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
406
 
407
+ elif model_type == "Wan-2.1-T2V":
408
  base_info = """### Wan-2.1-T2V
409
+ - Recommended batch size: ?
410
+ - Typical training time: ? hours
411
  - Default resolution: 49x512x768"""
412
 
413
  if training_type == "LoRA Finetune":
414
+ return base_info + "\n- Required VRAM: ?GB minimum\n- Default LoRA rank: 32 (~120 MB)"
415
  else:
416
  return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
417
 
 
  # Use the first matching preset
  preset = matching_presets[0]
  return {
+     "train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+     "batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
+     "learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
+     "save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
+     "lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
+     "lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
  }

  # Default fallbacks
  if model_type == "hunyuan_video":
      return {
+         "train_steps": DEFAULT_NB_TRAINING_STEPS,
+         "batch_size": DEFAULT_BATCH_SIZE,
          "learning_rate": 2e-5,
+         "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+         "lora_rank": DEFAULT_LORA_RANK_STR,
+         "lora_alpha": DEFAULT_LORA_ALPHA_STR
      }
  elif model_type == "ltx_video":
      return {
+         "train_steps": DEFAULT_NB_TRAINING_STEPS,
+         "batch_size": DEFAULT_BATCH_SIZE,
+         "learning_rate": DEFAULT_LEARNING_RATE,
+         "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+         "lora_rank": DEFAULT_LORA_RANK_STR,
+         "lora_alpha": DEFAULT_LORA_ALPHA_STR
      }
  elif model_type == "wan":
      return {
+         "train_steps": DEFAULT_NB_TRAINING_STEPS,
+         "batch_size": DEFAULT_BATCH_SIZE,
          "learning_rate": 5e-5,
+         "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
          "lora_rank": "32",
          "lora_alpha": "32"
      }
  else:
      # Generic defaults
      return {
+         "train_steps": DEFAULT_NB_TRAINING_STEPS,
+         "batch_size": DEFAULT_BATCH_SIZE,
+         "learning_rate": DEFAULT_LEARNING_RATE,
+         "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+         "lora_rank": DEFAULT_LORA_RANK_STR,
+         "lora_alpha": DEFAULT_LORA_ALPHA_STR
      }
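The `preset.get(key, DEFAULT_*)` pattern above means a preset only has to list the parameters it overrides. A hypothetical `TRAINING_PRESETS` entry (illustrative keys and values, not from this repo) showing the fallback:

    # Only learning_rate and train_steps are overridden here; batch_size,
    # save_iterations, lora_rank and lora_alpha all fall back to the
    # DEFAULT_* constants through preset.get(...).
    example_preset = {
        "model_type": "wan",
        "training_type": "lora",
        "learning_rate": 5e-5,
        "train_steps": 1000,
    }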

  def update_training_params(self, preset_name: str) -> Tuple:
 
      show_lora_params = preset["training_type"] == "lora"

      # Use preset defaults but preserve user-modified values if they exist
+     lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
+     lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR) else preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+     train_steps_val = current_state.get("train_steps") if current_state.get("train_steps") != preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS) else preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS)
+     batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
+     learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
+     save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)

      # Return values in the same order as the output components
      return (

          training_display_name,
          lora_rank_val,
          lora_alpha_val,
+         train_steps_val,
          batch_size_val,
          learning_rate_val,
          save_iterations_val,

          gr.Row(visible=show_lora_params)
      )
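One nit in the block above: an expression of the form `a if a != b else b` always evaluates to `a` (when the two sides are equal, returning `b` yields the same value), so each of these assignments reduces to plain `current_state.get(...)` and the preset value can never win, despite the comment. A sketch of logic that matches the comment's intent (None-safe; hypothetical helper, not part of this commit):

    def pick(field, default, current_state, preset):
        """Prefer the preset's value unless the user set something different."""
        preset_val = preset.get(field, default)
        current_val = current_state.get(field)
        # Keep the user's edit only when it exists and differs from the preset.
        return current_val if current_val not in (None, preset_val) else preset_val

    lora_rank_val = pick("lora_rank", DEFAULT_LORA_RANK_STR, current_state, preset)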

  def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
      """Get latest status message, log content, and status code in a safer way"""
      state = self.app.trainer.get_status()

      return (state["status"], state["message"], logs)

+ def get_status_updates(self):
+     """Get status updates for text components (no variant property)"""
      status, message, logs = self.get_latest_status_message_and_logs()

      # Get current task if available
      current_task = ""
      if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
          current_task = self.app.log_parser.get_current_task_display()

+     return message, logs, current_task
+
+ def get_button_updates(self):
+     """Get button updates (with variant property)"""
+     status, _, _ = self.get_latest_status_message_and_logs()
+
+     # Add checkpoints detection
+     has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
+
      is_training = status in ["training", "initializing"]
      is_completed = status in ["completed", "error", "stopped"]

      start_text = "Continue Training" if has_checkpoints else "Start Training"

+     # Create button updates
+     start_btn = gr.Button(
+         value=start_text,
+         interactive=not is_training,
+         variant="primary" if not is_training else "secondary"
+     )

+     stop_btn = gr.Button(
+         value="Stop at Last Checkpoint",
+         interactive=is_training,
+         variant="primary" if is_training else "secondary"
+     )
+
+     # Add delete_checkpoints_btn or pause_resume_btn
      if "delete_checkpoints_btn" in self.components:
+         third_btn = gr.Button(
+             "Delete All Checkpoints",
              interactive=has_checkpoints and not is_training,
+             variant="stop"
          )
      else:
+         third_btn = gr.Button(
+             "Resume Training" if status == "paused" else "Pause Training",
              interactive=(is_training or status == "paused") and not is_completed,
              variant="secondary",
              visible=False
          )

+     return start_btn, stop_btn, third_btn
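Note that `get_button_updates` always returns three updates (start, stop, third), so any event wired to it must supply three button components in `outputs`; if neither `delete_checkpoints_btn` nor `pause_resume_btn` exists, a two-element list would leave Gradio with more return values than outputs. A cheap wiring-time guard (sketch, not part of this commit):

    assert len(button_outputs) == 3, (
        f"get_button_updates returns 3 updates, got {len(button_outputs)} outputs"
    )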
+
+ def update_training_ui(self, training_state: Dict[str, Any]):
+     """Update UI components based on training state"""
+     updates = {}
+
+     # Update status box with high-level information
+     status_text = []
+     if training_state["status"] != "idle":
+         status_text.extend([
+             f"Status: {training_state['status']}",
+             f"Progress: {training_state['progress']}",
+             f"Step: {training_state['current_step']}/{training_state['total_steps']}",
+             f"Time elapsed: {training_state['elapsed']}",
+             f"Estimated remaining: {training_state['remaining']}",
+             "",
+             f"Current loss: {training_state['step_loss']}",
+             f"Learning rate: {training_state['learning_rate']}",
+             f"Gradient norm: {training_state['grad_norm']}",
+             f"Memory usage: {training_state['memory']}"
+         ])
+
+     if training_state["error_message"]:
+         status_text.append(f"\nError: {training_state['error_message']}")
+
+     updates["status_box"] = "\n".join(status_text)
+
+     # Add current task information to the dedicated box
+     if training_state.get("current_task"):
+         updates["current_task_box"] = training_state["current_task"]
+     else:
+         updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."
+
+     return updates
+
+ def handle_pause_resume(self):
+     """Handle pause/resume button click"""
+     status, _, _ = self.get_latest_status_message_and_logs()
+
+     if status == "paused":
+         self.app.trainer.resume_training()
+     else:
+         self.app.trainer.pause_training()
+
+     # Return the updates separately for text and buttons
+     return (*self.get_status_updates(), *self.get_button_updates())
+
+ def handle_stop(self):
+     """Handle stop button click"""
+     self.app.trainer.stop_training()
+
+     # Return the updates separately for text and buttons
+     return (*self.get_status_updates(), *self.get_button_updates())
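Both handlers splice the text updates and the button updates into a single flat tuple, so any `outputs` list wired to them must be ordered text components first (status, logs, current task), then the three buttons. Shape check (illustrative; `tab` is a hypothetical TrainTab instance):

    updates = (*tab.get_status_updates(), *tab.get_button_updates())
    # (message, logs, current_task, start_btn, stop_btn, third_btn)
    assert len(updates) == 6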
vms/ui/video_trainer_ui.py CHANGED
@@ -9,7 +9,12 @@ from ..services import TrainingService, CaptioningService, SplittingService, Imp
  from ..config import (
      STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
      TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
-     MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES
+     MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
+     DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+     DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+     DEFAULT_LEARNING_RATE,
+     DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+     DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
  )
  from ..utils import count_media_files, format_media_title, TrainingLogParser
  from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
@@ -92,7 +97,7 @@ class VideoTrainerUI:
      self.tabs["train_tab"].components["training_type"],
      self.tabs["train_tab"].components["lora_rank"],
      self.tabs["train_tab"].components["lora_alpha"],
-     self.tabs["train_tab"].components["num_epochs"],
+     self.tabs["train_tab"].components["train_steps"],
      self.tabs["train_tab"].components["batch_size"],
      self.tabs["train_tab"].components["learning_rate"],
      self.tabs["train_tab"].components["save_iterations"],
@@ -104,31 +109,33 @@ class VideoTrainerUI:
  def _add_timers(self):
      """Add auto-refresh timers to the UI"""
-     # Status update timer (every 1 second)
+     # Status update timer for text components (every 1 second)
      status_timer = gr.Timer(value=1)
+     status_timer.tick(
+         fn=self.tabs["train_tab"].get_status_updates,  # Use a new function that returns appropriate updates
+         outputs=[
+             self.tabs["train_tab"].components["status_box"],
+             self.tabs["train_tab"].components["log_box"],
+             self.tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.tabs["train_tab"].components else None
+         ]
+     )

-     # Use a safer approach - check if the component exists before using it
-     outputs = [
-         self.tabs["train_tab"].components["status_box"],
-         self.tabs["train_tab"].components["log_box"],
+     # Button update timer for button components (every 1 second)
+     button_timer = gr.Timer(value=1)
+     button_outputs = [
          self.tabs["train_tab"].components["start_btn"],
          self.tabs["train_tab"].components["stop_btn"]
      ]

-     # Add current_task_box component
-     if "current_task_box" in self.tabs["train_tab"].components:
-         outputs.append(self.tabs["train_tab"].components["current_task_box"])
-
-     # Add delete_checkpoints_btn only if it exists
+     # Add delete_checkpoints_btn or pause_resume_btn as the third button
      if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
-         outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
-     else:
-         # Add pause_resume_btn as fallback
-         outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
+         button_outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
+     elif "pause_resume_btn" in self.tabs["train_tab"].components:
+         button_outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])

-     status_timer.tick(
-         fn=self.tabs["train_tab"].get_latest_status_message_logs_and_button_labels,
-         outputs=outputs
+     button_timer.tick(
+         fn=self.tabs["train_tab"].get_button_updates,  # Use a new function for button-specific updates
+         outputs=button_outputs
      )

      # Dataset refresh timer (every 5 seconds)
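A detail to watch in the hunk above: the inline conditional leaves a literal `None` in `outputs` when `current_task_box` is missing, and `None` is not a valid output component. Building the list conditionally is safer (sketch, not part of this commit):

    text_outputs = [
        self.tabs["train_tab"].components["status_box"],
        self.tabs["train_tab"].components["log_box"],
    ]
    # Only append the task box when the component actually exists.
    if "current_task_box" in self.tabs["train_tab"].components:
        text_outputs.append(self.tabs["train_tab"].components["current_task_box"])
    status_timer.tick(fn=self.tabs["train_tab"].get_status_updates, outputs=text_outputs)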
@@ -175,6 +182,11 @@ class VideoTrainerUI:
  if "model_type" in recovery_ui:
      model_type_value = recovery_ui["model_type"]

+     # Remove " (LoRA)" suffix if present
+     if " (LoRA)" in model_type_value:
+         model_type_value = model_type_value.replace(" (LoRA)", "")
+         logger.info(f"Removed (LoRA) suffix from model type: {model_type_value}")
+
      # If it's an internal name, convert to display name
      if model_type_value not in MODEL_TYPES:
          # Find the display name for this internal model type
@@ -201,7 +213,7 @@ class VideoTrainerUI:
          ui_state["training_type"] = training_type_value

      # Copy other parameters
-     for param in ["lora_rank", "lora_alpha", "num_epochs",
+     for param in ["lora_rank", "lora_alpha", "train_steps",
                    "batch_size", "learning_rate", "save_iterations", "training_preset"]:
          if param in recovery_ui:
              ui_state[param] = recovery_ui[param]
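This commit replaces the epoch-based `num_epochs` setting with an absolute `train_steps` count throughout. For anyone migrating a saved value, the usual equivalence (assuming one optimizer step per batch) is steps = epochs * ceil(dataset_size / batch_size), e.g.:

    import math

    num_epochs = 70        # old persisted setting
    dataset_size = 100     # number of training clips (illustrative)
    batch_size = 1
    train_steps = num_epochs * math.ceil(dataset_size / batch_size)  # 7000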
@@ -216,31 +228,55 @@ class VideoTrainerUI:
  # Load values (potentially with recovery updates applied)
  ui_state = self.load_ui_values()

- # Ensure model_type is a display name, not internal name
+ # Ensure model_type is a valid display name
  model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+ # Remove " (LoRA)" suffix if present
+ if " (LoRA)" in model_type_val:
+     model_type_val = model_type_val.replace(" (LoRA)", "")
+     logger.info(f"Removed (LoRA) suffix from model type: {model_type_val}")
+
+ # Ensure it's a valid model type in the dropdown
  if model_type_val not in MODEL_TYPES:
-     # Convert from internal to display name
+     # Convert from internal to display name or use default
+     model_type_found = False
      for display_name, internal_name in MODEL_TYPES.items():
          if internal_name == model_type_val:
              model_type_val = display_name
+             model_type_found = True
              break
+     # If still not found, use the first model type
+     if not model_type_found:
+         model_type_val = list(MODEL_TYPES.keys())[0]
+         logger.warning(f"Invalid model type '{model_type_val}', using default: {model_type_val}")

- # Ensure training_type is a display name, not internal name
+ # Ensure training_type is a valid display name
  training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
  if training_type_val not in TRAINING_TYPES:
-     # Convert from internal to display name
+     # Convert from internal to display name or use default
+     training_type_found = False
      for display_name, internal_name in TRAINING_TYPES.items():
          if internal_name == training_type_val:
              training_type_val = display_name
+             training_type_found = True
              break
+     # If still not found, use the first training type
+     if not training_type_found:
+         training_type_val = list(TRAINING_TYPES.keys())[0]
+         logger.warning(f"Invalid training type '{training_type_val}', using default: {training_type_val}")

+ # Validate training preset
  training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
- lora_rank_val = ui_state.get("lora_rank", "128")
- lora_alpha_val = ui_state.get("lora_alpha", "128")
- num_epochs_val = int(ui_state.get("num_epochs", 70))
- batch_size_val = int(ui_state.get("batch_size", 1))
- learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
- save_iterations_val = int(ui_state.get("save_iterations", 500))
+ if training_preset not in TRAINING_PRESETS:
+     training_preset = list(TRAINING_PRESETS.keys())[0]
+     logger.warning(f"Invalid training preset '{training_preset}', using default: {training_preset}")
+
+ # Rest of the function remains unchanged
+ lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
+ lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+ train_steps_val = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+ batch_size_val = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
+ learning_rate_val = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
+ save_iterations_val = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))

  # Initial current task value
  current_task_val = ""
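Small logging nit in the fallback branches above: `model_type_val` (and likewise `training_type_val` and `training_preset`) is overwritten before the warning is formatted, so both placeholders print the default rather than the rejected value. Capturing the bad value first fixes the message (sketch):

    if not model_type_found:
        invalid_value = model_type_val                 # keep the rejected value
        model_type_val = list(MODEL_TYPES.keys())[0]   # then fall back
        logger.warning(f"Invalid model type '{invalid_value}', using default: {model_type_val}")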
@@ -259,7 +295,7 @@ class VideoTrainerUI:
      training_type_val,
      lora_rank_val,
      lora_alpha_val,
-     num_epochs_val,
+     train_steps_val,
      batch_size_val,
      learning_rate_val,
      save_iterations_val,
@@ -275,12 +311,12 @@ class VideoTrainerUI:
      ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
      ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
      ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
-     ui_state.get("lora_rank", "128"),
-     ui_state.get("lora_alpha", "128"),
-     ui_state.get("num_epochs", 70),
-     ui_state.get("batch_size", 1),
-     ui_state.get("learning_rate", 3e-5),
-     ui_state.get("save_iterations", 500)
+     ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
+     ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
+     ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+     ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
+     ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+     ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
  )

  def update_ui_state(self, **kwargs):
@@ -296,12 +332,12 @@ class VideoTrainerUI:
      ui_state = self.trainer.load_ui_state()

      # Ensure proper type conversion for numeric values
-     ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
-     ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
-     ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
-     ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
-     ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
-     ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
+     ui_state["lora_rank"] = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
+     ui_state["lora_alpha"] = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+     ui_state["train_steps"] = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+     ui_state["batch_size"] = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
+     ui_state["learning_rate"] = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
+     ui_state["save_iterations"] = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))

      return ui_state
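The `int(...)`/`float(...)` coercions above raise `ValueError` if a persisted UI-state file carries a malformed value. A tolerant variant would fall back to the default instead (hypothetical helper, not part of this commit):

    def as_int(value, default: int) -> int:
        """Coerce persisted state to int, falling back on bad input."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    ui_state["train_steps"] = as_int(ui_state.get("train_steps"), DEFAULT_NB_TRAINING_STEPS)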
 
 