Commit c6546ad (parent: 38cfbff): "cleaning code"

Files changed:
- vms/config.py: +116 -89
- vms/services/trainer.py: +90 -47
- vms/tabs/train_tab.py: +153 -177
- vms/ui/video_trainer_ui.py: +79 -43
vms/config.py  CHANGED

@@ -58,9 +58,9 @@ JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
 
 # Expanded model types to include Wan-2.1-T2V
 MODEL_TYPES = {
-    "HunyuanVideo …
-    "LTX-Video …
-    "Wan-2.1-T2V …
+    "HunyuanVideo": "hunyuan_video",
+    "LTX-Video": "ltx_video",
+    "Wan-2.1-T2V": "wan"
 }
 
 # Training types

@@ -69,6 +69,23 @@ TRAINING_TYPES = {
     "Full Finetune": "full-finetune"
 }
 
+DEFAULT_SEED = 42
+
+DEFAULT_NB_TRAINING_STEPS = 1000
+
+DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200
+
+DEFAULT_LORA_RANK = 128
+DEFAULT_LORA_RANK_STR = str(DEFAULT_LORA_RANK)
+
+DEFAULT_LORA_ALPHA = 128
+DEFAULT_LORA_ALPHA_STR = str(DEFAULT_LORA_ALPHA)
+
+DEFAULT_CAPTION_DROPOUT_P = 0.05
+
+DEFAULT_BATCH_SIZE = 1
+
+DEFAULT_LEARNING_RATE = 3e-5
 
 # it is best to use resolutions that are powers of 8
 # The resolution should be divisible by 32

@@ -87,39 +104,49 @@ MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
 NB_FRAMES_1 = 1 # 1
 NB_FRAMES_9 = 8 + 1 # 8 + 1
 NB_FRAMES_17 = 8 * 2 + 1 # 16 + 1
-    … (13 removed lines, truncated in this capture)
+NB_FRAMES_33 = 8 * 4 + 1 # 32 + 1
+NB_FRAMES_49 = 8 * 6 + 1 # 48 + 1
+NB_FRAMES_65 = 8 * 8 + 1 # 64 + 1
+NB_FRAMES_81 = 8 * 10 + 1 # 80 + 1
+NB_FRAMES_97 = 8 * 12 + 1 # 96 + 1
+NB_FRAMES_113 = 8 * 14 + 1 # 112 + 1
+NB_FRAMES_129 = 8 * 16 + 1 # 128 + 1
+NB_FRAMES_145 = 8 * 18 + 1 # 144 + 1
+NB_FRAMES_161 = 8 * 20 + 1 # 160 + 1
+NB_FRAMES_177 = 8 * 22 + 1 # 176 + 1
+NB_FRAMES_193 = 8 * 24 + 1 # 192 + 1
+NB_FRAMES_225 = 8 * 28 + 1 # 224 + 1
+NB_FRAMES_257 = 8 * 32 + 1 # 256 + 1
 # 256 isn't a lot by the way, especially with 60 FPS videos..
 # can we crank it and put more frames in here?
 
+NB_FRAMES_273 = 8 * 34 + 1 # 272 + 1
+NB_FRAMES_289 = 8 * 36 + 1 # 288 + 1
+NB_FRAMES_305 = 8 * 38 + 1 # 304 + 1
+NB_FRAMES_321 = 8 * 40 + 1 # 320 + 1
+NB_FRAMES_337 = 8 * 42 + 1 # 336 + 1
+NB_FRAMES_353 = 8 * 44 + 1 # 352 + 1
+NB_FRAMES_369 = 8 * 46 + 1 # 368 + 1
+NB_FRAMES_385 = 8 * 48 + 1 # 384 + 1
+NB_FRAMES_401 = 8 * 50 + 1 # 400 + 1
+
 SMALL_TRAINING_BUCKETS = [
     (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
     (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
     (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
-    ( … (13 removed lines, truncated in this capture)
+    (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
 MEDIUM_19_9_RATIO_WIDTH = 928 # 32 * 29

@@ -129,19 +156,19 @@ MEDIUM_19_9_RATIO_BUCKETS = [
     (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
     (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
     (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
-    ( … (13 removed lines, truncated in this capture)
+    (NB_FRAMES_33, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_49, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_65, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_81, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_97, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_113, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_129, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_145, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_161, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_177, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_193, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_225, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_257, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
 # Updated training presets to include Wan-2.1-T2V and support both LoRA and full-finetune

@@ -149,24 +176,24 @@ TRAINING_PRESETS = {
     "HunyuanVideo (normal)": {
         "model_type": "hunyuan_video",
        "training_type": "lora",
-        "lora_rank": …
-        "lora_alpha": …
-        "…
-        "batch_size": …
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
         "learning_rate": 2e-5,
-        "save_iterations": …
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "none"
     },
     "LTX-Video (normal)": {
         "model_type": "ltx_video",
         "training_type": "lora",
-        "lora_rank": …
-        "lora_alpha": …
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "lora_rank": DEFAULT_LORA_RANK_STR,
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },

@@ -174,21 +201,21 @@ TRAINING_PRESETS = {
         "model_type": "ltx_video",
         "training_type": "lora",
         "lora_rank": "256",
-        "lora_alpha": …
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },
     "LTX-Video (Full Finetune)": {
         "model_type": "ltx_video",
         "training_type": "full-finetune",
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },

@@ -197,10 +224,10 @@ TRAINING_PRESETS = {
         "training_type": "lora",
         "lora_rank": "32",
         "lora_alpha": "32",
-        "…
-        "batch_size": …
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
         "learning_rate": 5e-5,
-        "save_iterations": …
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": SMALL_TRAINING_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     },

@@ -209,10 +236,10 @@ TRAINING_PRESETS = {
         "training_type": "lora",
         "lora_rank": "64",
         "lora_alpha": "64",
-        "…
-        "batch_size": …
-        "learning_rate": …
-        "save_iterations": …
+        "train_steps": DEFAULT_NB_TRAINING_STEPS,
+        "batch_size": DEFAULT_BATCH_SIZE,
+        "learning_rate": DEFAULT_LEARNING_RATE,
+        "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
         "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
         "flow_weighting_scheme": "logit_normal"
     }

@@ -244,7 +271,7 @@ class TrainingConfig:
     id_token: Optional[str] = None
     video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SMALL_TRAINING_BUCKETS)
     video_reshape_mode: str = "center"
-    caption_dropout_p: float = …
+    caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
     caption_dropout_technique: str = "empty"
     precompute_conditions: bool = False
 

@@ -257,16 +284,16 @@ class TrainingConfig:
 
     # Training arguments
     training_type: str = "lora"
-    seed: int = …
+    seed: int = DEFAULT_SEED
     mixed_precision: str = "bf16"
     batch_size: int = 1
-    …
-    lora_rank: int = …
-    lora_alpha: int = …
+    train_step: int = DEFAULT_NB_TRAINING_STEPS
+    lora_rank: int = DEFAULT_LORA_RANK
+    lora_alpha: int = DEFAULT_LORA_ALPHA
    target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
     gradient_accumulation_steps: int = 1
     gradient_checkpointing: bool = True
-    checkpointing_steps: int = …
+    checkpointing_steps: int = DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS
     checkpointing_limit: Optional[int] = 2
     resume_from_checkpoint: Optional[str] = None
     enable_slicing: bool = True

@@ -300,15 +327,15 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            …
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=2e-5,
             gradient_checkpointing=True,
             id_token="afkx",
             gradient_accumulation_steps=1,
-            lora_rank=…
-            lora_alpha=…
+            lora_rank=DEFAULT_LORA_RANK,
+            lora_alpha=DEFAULT_LORA_ALPHA,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="none", # Hunyuan specific
             training_type="lora"
         )

@@ -322,15 +349,15 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            …
-            lr=…
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
+            lr=DEFAULT_LEARNING_RATE,
             gradient_checkpointing=True,
             id_token="BW_STYLE",
             gradient_accumulation_steps=4,
-            lora_rank=…
-            lora_alpha=…
+            lora_rank=DEFAULT_LORA_RANK,
+            lora_alpha=DEFAULT_LORA_ALPHA,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # LTX specific
             training_type="lora"
         )

@@ -344,13 +371,13 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            …
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=1e-5,
             gradient_checkpointing=True,
             id_token="BW_STYLE",
             gradient_accumulation_steps=1,
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # LTX specific
             training_type="full-finetune"
         )

@@ -364,7 +391,7 @@ class TrainingConfig:
             data_root=data_path,
             output_dir=output_path,
             batch_size=1,
-            …
+            train_steps=DEFAULT_NB_TRAINING_STEPS,
             lr=5e-5,
             gradient_checkpointing=True,
             id_token=None, # Default is no ID token for Wan

@@ -373,7 +400,7 @@ class TrainingConfig:
             lora_alpha=32,
             target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"], # Wan-specific target modules
             video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
-            caption_dropout_p=…
+            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
             flow_weighting_scheme="logit_normal", # Wan specific
             training_type="lora"
         )

@@ -428,7 +455,7 @@ class TrainingConfig:
         #args.extend(["--mixed_precision", self.mixed_precision])
 
         args.extend(["--batch_size", str(self.batch_size)])
-        args.extend(["--train_steps", str(self.…
+        args.extend(["--train_steps", str(self.train_steps)])
 
         # LoRA specific arguments
         if self.training_type == "lora":
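All of the frame-count constants and bucket entries introduced above follow the same 8 * k + 1 pattern, paired with a fixed height and width. As an illustrative aside (not part of this commit), the same tables could be generated instead of enumerated; the sketch below assumes that convention, and the helper name and the 512x768 resolution are placeholders rather than the repository's values.

# Hypothetical helper, not part of vms/config.py: rebuilds a bucket table like
# SMALL_TRAINING_BUCKETS from the "8 * k + 1 frames" convention used in the diff above.
from typing import List, Tuple

def make_buckets(ks: List[int], height: int, width: int) -> List[Tuple[int, int, int]]:
    """Return (num_frames, height, width) tuples, starting with the single-frame bucket."""
    buckets = [(1, height, width)]                        # NB_FRAMES_1
    buckets += [(8 * k + 1, height, width) for k in ks]   # NB_FRAMES_9 ... NB_FRAMES_257
    return buckets

# Frame counts matching SMALL_TRAINING_BUCKETS above: 1, 9, 17, 33, ..., 225, 257.
# The real height/width come from the MEDIUM_19_9_RATIO_HEIGHT / _WIDTH constants in config.py.
small_buckets = make_buckets([1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 28, 32],
                             height=512, width=768)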
vms/services/trainer.py  CHANGED

@@ -23,7 +23,12 @@ from huggingface_hub import upload_folder, create_repo
 from ..config import (
     TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
     STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
-    MODEL_TYPES, TRAINING_TYPES
+    MODEL_TYPES, TRAINING_TYPES,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 )
 from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
 

@@ -111,18 +116,19 @@ class TrainingService:
        except Exception as e:
             logger.error(f"Error saving UI state: {str(e)}")
 
+    # Additional fix for the load_ui_state method in trainer.py to clean up old values
     def load_ui_state(self) -> Dict[str, Any]:
         """Load saved UI state"""
         ui_state_file = OUTPUT_PATH / "ui_state.json"
         default_state = {
             "model_type": list(MODEL_TYPES.keys())[0],
             "training_type": list(TRAINING_TYPES.keys())[0],
-            "lora_rank": …
-            "lora_alpha": …
-            "…
-            "batch_size": …
-            "learning_rate": …
-            "save_iterations": …
+            "lora_rank": DEFAULT_LORA_RANK_STR,
+            "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+            "train_steps": DEFAULT_NB_TRAINING_STEPS,
+            "batch_size": DEFAULT_BATCH_SIZE,
+            "learning_rate": DEFAULT_LEARNING_RATE,
+            "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
             "training_preset": list(TRAINING_PRESETS.keys())[0]
         }
 

@@ -145,9 +151,14 @@ class TrainingService:
 
             saved_state = json.loads(file_content)
 
+            # Clean up model type if it contains " (LoRA)" suffix
+            if "model_type" in saved_state and " (LoRA)" in saved_state["model_type"]:
+                saved_state["model_type"] = saved_state["model_type"].replace(" (LoRA)", "")
+                logger.info(f"Removed (LoRA) suffix from saved model type: {saved_state['model_type']}")
+
             # Convert numeric values to appropriate types
-            if "…
-                saved_state["…
+            if "train_steps" in saved_state:
+                saved_state["train_steps"] = int(saved_state["train_steps"])
             if "batch_size" in saved_state:
                 saved_state["batch_size"] = int(saved_state["batch_size"])
             if "learning_rate" in saved_state:

@@ -158,6 +169,40 @@ class TrainingService:
             # Make sure we have all keys (in case structure changed)
             merged_state = default_state.copy()
             merged_state.update(saved_state)
+
+            # Validate model_type is in available choices
+            if merged_state["model_type"] not in MODEL_TYPES:
+                # Try to map from internal name
+                model_found = False
+                for display_name, internal_name in MODEL_TYPES.items():
+                    if internal_name == merged_state["model_type"]:
+                        merged_state["model_type"] = display_name
+                        model_found = True
+                        break
+                # If still not found, use default
+                if not model_found:
+                    merged_state["model_type"] = default_state["model_type"]
+                    logger.warning(f"Invalid model type in saved state, using default")
+
+            # Validate training_type is in available choices
+            if merged_state["training_type"] not in TRAINING_TYPES:
+                # Try to map from internal name
+                training_found = False
+                for display_name, internal_name in TRAINING_TYPES.items():
+                    if internal_name == merged_state["training_type"]:
+                        merged_state["training_type"] = display_name
+                        training_found = True
+                        break
+                # If still not found, use default
+                if not training_found:
+                    merged_state["training_type"] = default_state["training_type"]
+                    logger.warning(f"Invalid training type in saved state, using default")
+
+            # Validate training_preset is in available choices
+            if merged_state["training_preset"] not in TRAINING_PRESETS:
+                merged_state["training_preset"] = default_state["training_preset"]
+                logger.warning(f"Invalid training preset in saved state, using default")
+
             return merged_state
         except json.JSONDecodeError as e:
             logger.error(f"Error parsing UI state JSON: {str(e)}")

@@ -176,12 +221,12 @@ class TrainingService:
             default_state = {
                 "model_type": list(MODEL_TYPES.keys())[0],
                 "training_type": list(TRAINING_TYPES.keys())[0],
-                "lora_rank": …
-                "lora_alpha": …
-                "…
-                "batch_size": …
-                "learning_rate": …
-                "save_iterations": …
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
                 "training_preset": list(TRAINING_PRESETS.keys())[0]
             }
             self.save_ui_state(default_state)

@@ -209,12 +254,12 @@ class TrainingService:
             default_state = {
                 "model_type": list(MODEL_TYPES.keys())[0],
                 "training_type": list(TRAINING_TYPES.keys())[0],
-                "lora_rank": …
-                "lora_alpha": …
-                "…
-                "batch_size": …
-                "learning_rate": …
-                "save_iterations": …
+                "lora_rank": DEFAULT_LORA_RANK_STR,
+                "lora_alpha": DEFAULT_LORA_ALPHA_STR,
+                "train_steps": DEFAULT_NB_TRAINING_STEPS,
+                "batch_size": DEFAULT_BATCH_SIZE,
+                "learning_rate": DEFAULT_LEARNING_RATE,
+                "save_iterations": DEFAULT_NB_TRAINING_STEPS,
                 "training_preset": list(TRAINING_PRESETS.keys())[0]
             }
             self.save_ui_state(default_state)

@@ -361,7 +406,7 @@
         model_type: str,
         lora_rank: str,
         lora_alpha: str,
-        …
+        train_steps: int,
         batch_size: int,
         learning_rate: float,
         save_iterations: int,

@@ -508,7 +553,7 @@
             return error_msg, "Unsupported model"
 
         # Update with UI parameters
-        config.…
+        config.train_steps = int(train_steps)
         config.batch_size = int(batch_size)
         config.lr = float(learning_rate)
         config.checkpointing_steps = int(save_iterations)

@@ -530,11 +575,11 @@
 
         # Common settings for both models
         config.mixed_precision = "bf16"
-        config.seed = …
+        config.seed = DEFAULT_SEED
         config.gradient_checkpointing = True
         config.enable_slicing = True
         config.enable_tiling = True
-        config.caption_dropout_p = …
+        config.caption_dropout_p = DEFAULT_CAPTION_DROPOUT_P
 
         validation_error = self.validate_training_config(config, model_type)
         if validation_error:

@@ -626,7 +671,7 @@
             "training_type": training_type,
             "lora_rank": lora_rank,
             "lora_alpha": lora_alpha,
-            "…
+            "train_steps": train_steps,
             "batch_size": batch_size,
             "learning_rate": learning_rate,
             "save_iterations": save_iterations,

@@ -635,14 +680,12 @@
         })
 
         # Update initial training status
-        total_steps = …
+        total_steps = int(train_steps)
         self.save_status(
             state='training',
-            epoch=0,
             step=0,
             total_steps=total_steps,
             loss=0.0,
-            total_epochs=num_epochs,
             message='Training started',
             repo_id=repo_id,
             model_type=model_type,

@@ -789,12 +832,12 @@
             "params": {
                 "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
                 "training_type": TRAINING_TYPES.get(ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])),
-                "lora_rank": ui_state.get("lora_rank", …
-                "lora_alpha": ui_state.get("lora_alpha", …
-                "…
-                "batch_size": ui_state.get("batch_size", …
-                "learning_rate": ui_state.get("learning_rate", …
-                "save_iterations": ui_state.get("save_iterations", …
+                "lora_rank": ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
+                "lora_alpha": ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
+                "train_steps": ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+                "batch_size": ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
+                "learning_rate": ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+                "save_iterations": ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
                 "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
                 "repo_id": "" # Default empty repo ID
             }

@@ -853,12 +896,12 @@
         ui_updates.update({
             "model_type": model_type_display, # Use the display name for the UI dropdown
             "training_type": training_type_display, # Use the display name for training type
-            "lora_rank": params.get('lora_rank', …
-            "lora_alpha": params.get('lora_alpha', …
-            "…
-            "batch_size": params.get('batch_size', …
-            "learning_rate": params.get('learning_rate', …
-            "save_iterations": params.get('save_iterations', …
+            "lora_rank": params.get('lora_rank', DEFAULT_LORA_RANK_STR),
+            "lora_alpha": params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
+            "train_steps": params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
+            "batch_size": params.get('batch_size', DEFAULT_BATCH_SIZE),
+            "learning_rate": params.get('learning_rate', DEFAULT_LEARNING_RATE),
+            "save_iterations": params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
             "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
         })
 

@@ -872,12 +915,12 @@
         # But keep model_type_display for the UI
         result = self.start_training(
             model_type=model_type_internal,
-            lora_rank=params.get('lora_rank', …
-            lora_alpha=params.get('lora_alpha', …
-            …
-            batch_size=params.get('batch_size', …
-            learning_rate=params.get('learning_rate', …
-            save_iterations=params.get('save_iterations', …
+            lora_rank=params.get('lora_rank', DEFAULT_LORA_RANK_STR),
+            lora_alpha=params.get('lora_alpha', DEFAULT_LORA_ALPHA_STR),
+            train_size=params.get('train_steps', DEFAULT_NB_TRAINING_STEPS),
+            batch_size=params.get('batch_size', DEFAULT_BATCH_SIZE),
+            learning_rate=params.get('learning_rate', DEFAULT_LEARNING_RATE),
+            save_iterations=params.get('save_iterations', DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
             repo_id=params.get('repo_id', ''),
             preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
             training_type=training_type_internal,
vms/tabs/train_tab.py
CHANGED
@@ -9,7 +9,14 @@ from typing import Dict, Any, List, Optional, Tuple
|
|
9 |
from pathlib import Path
|
10 |
|
11 |
from .base_tab import BaseTab
|
12 |
-
from ..config import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
@@ -63,20 +70,20 @@ class TrainTab(BaseTab):
|
|
63 |
self.components["lora_rank"] = gr.Dropdown(
|
64 |
label="LoRA Rank",
|
65 |
choices=["16", "32", "64", "128", "256", "512", "1024"],
|
66 |
-
value=
|
67 |
type="value"
|
68 |
)
|
69 |
self.components["lora_alpha"] = gr.Dropdown(
|
70 |
label="LoRA Alpha",
|
71 |
choices=["16", "32", "64", "128", "256", "512", "1024"],
|
72 |
-
value=
|
73 |
type="value"
|
74 |
)
|
75 |
|
76 |
with gr.Row():
|
77 |
-
self.components["
|
78 |
-
label="Number of
|
79 |
-
value=
|
80 |
minimum=1,
|
81 |
precision=0
|
82 |
)
|
@@ -89,13 +96,13 @@ class TrainTab(BaseTab):
|
|
89 |
with gr.Row():
|
90 |
self.components["learning_rate"] = gr.Number(
|
91 |
label="Learning Rate",
|
92 |
-
value=
|
93 |
-
minimum=1e-
|
94 |
)
|
95 |
self.components["save_iterations"] = gr.Number(
|
96 |
label="Save checkpoint every N iterations",
|
97 |
-
value=
|
98 |
-
minimum=
|
99 |
precision=0,
|
100 |
info="Model will be saved periodically after these many steps"
|
101 |
)
|
@@ -170,7 +177,7 @@ class TrainTab(BaseTab):
|
|
170 |
|
171 |
return {
|
172 |
self.components["model_info"]: info,
|
173 |
-
self.components["
|
174 |
self.components["batch_size"]: params["batch_size"],
|
175 |
self.components["learning_rate"]: params["learning_rate"],
|
176 |
self.components["save_iterations"]: params["save_iterations"],
|
@@ -186,7 +193,7 @@ class TrainTab(BaseTab):
|
|
186 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
187 |
outputs=[
|
188 |
self.components["model_info"],
|
189 |
-
self.components["
|
190 |
self.components["batch_size"],
|
191 |
self.components["learning_rate"],
|
192 |
self.components["save_iterations"],
|
@@ -204,7 +211,7 @@ class TrainTab(BaseTab):
|
|
204 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
205 |
outputs=[
|
206 |
self.components["model_info"],
|
207 |
-
self.components["
|
208 |
self.components["batch_size"],
|
209 |
self.components["learning_rate"],
|
210 |
self.components["save_iterations"],
|
@@ -225,9 +232,9 @@ class TrainTab(BaseTab):
|
|
225 |
outputs=[]
|
226 |
)
|
227 |
|
228 |
-
self.components["
|
229 |
-
fn=lambda v: self.app.update_ui_state(
|
230 |
-
inputs=[self.components["
|
231 |
outputs=[]
|
232 |
)
|
233 |
|
@@ -262,7 +269,7 @@ class TrainTab(BaseTab):
|
|
262 |
self.components["training_type"],
|
263 |
self.components["lora_rank"],
|
264 |
self.components["lora_alpha"],
|
265 |
-
self.components["
|
266 |
self.components["batch_size"],
|
267 |
self.components["learning_rate"],
|
268 |
self.components["save_iterations"],
|
@@ -280,7 +287,7 @@ class TrainTab(BaseTab):
|
|
280 |
self.components["training_type"],
|
281 |
self.components["lora_rank"],
|
282 |
self.components["lora_alpha"],
|
283 |
-
self.components["
|
284 |
self.components["batch_size"],
|
285 |
self.components["learning_rate"],
|
286 |
self.components["save_iterations"],
|
@@ -290,27 +297,20 @@ class TrainTab(BaseTab):
|
|
290 |
self.components["status_box"],
|
291 |
self.components["log_box"]
|
292 |
]
|
293 |
-
).success(
|
294 |
-
fn=self.get_latest_status_message_logs_and_button_labels,
|
295 |
-
outputs=[
|
296 |
-
self.components["status_box"],
|
297 |
-
self.components["log_box"],
|
298 |
-
self.components["start_btn"],
|
299 |
-
self.components["stop_btn"],
|
300 |
-
self.components["pause_resume_btn"],
|
301 |
-
self.components["current_task_box"] # Include new component
|
302 |
-
]
|
303 |
)
|
304 |
|
|
|
|
|
|
|
305 |
self.components["pause_resume_btn"].click(
|
306 |
fn=self.handle_pause_resume,
|
307 |
outputs=[
|
308 |
self.components["status_box"],
|
309 |
self.components["log_box"],
|
|
|
310 |
self.components["start_btn"],
|
311 |
self.components["stop_btn"],
|
312 |
-
|
313 |
-
self.components["current_task_box"] # Include new component
|
314 |
]
|
315 |
)
|
316 |
|
@@ -319,10 +319,10 @@ class TrainTab(BaseTab):
|
|
319 |
outputs=[
|
320 |
self.components["status_box"],
|
321 |
self.components["log_box"],
|
|
|
322 |
self.components["start_btn"],
|
323 |
self.components["stop_btn"],
|
324 |
-
|
325 |
-
self.components["current_task_box"] # Include new component
|
326 |
]
|
327 |
)
|
328 |
|
@@ -330,16 +330,6 @@ class TrainTab(BaseTab):
|
|
330 |
self.components["delete_checkpoints_btn"].click(
|
331 |
fn=lambda: self.app.trainer.delete_all_checkpoints(),
|
332 |
outputs=[self.components["status_box"]]
|
333 |
-
).then(
|
334 |
-
fn=self.get_latest_status_message_logs_and_button_labels,
|
335 |
-
outputs=[
|
336 |
-
self.components["status_box"],
|
337 |
-
self.components["log_box"],
|
338 |
-
self.components["start_btn"],
|
339 |
-
self.components["stop_btn"],
|
340 |
-
self.components["delete_checkpoints_btn"],
|
341 |
-
self.components["current_task_box"] # Include new component
|
342 |
-
]
|
343 |
)
|
344 |
|
345 |
def handle_training_start(self, preset, model_type, training_type, *args):
|
@@ -391,7 +381,7 @@ class TrainTab(BaseTab):
|
|
391 |
|
392 |
def get_model_info(self, model_type: str, training_type: str) -> str:
|
393 |
"""Get information about the selected model type and training method"""
|
394 |
-
if model_type == "HunyuanVideo
|
395 |
base_info = """### HunyuanVideo
|
396 |
- Required VRAM: ~48GB minimum
|
397 |
- Recommended batch size: 1-2
|
@@ -403,7 +393,7 @@ class TrainTab(BaseTab):
|
|
403 |
else:
|
404 |
return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
|
405 |
|
406 |
-
elif model_type == "LTX-Video
|
407 |
base_info = """### LTX-Video
|
408 |
- Recommended batch size: 1-4
|
409 |
- Typical training time: 1-3 hours
|
@@ -414,14 +404,14 @@ class TrainTab(BaseTab):
|
|
414 |
else:
|
415 |
return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
|
416 |
|
417 |
-
elif model_type == "Wan-2.1-T2V
|
418 |
base_info = """### Wan-2.1-T2V
|
419 |
-
- Recommended batch size:
|
420 |
-
- Typical training time:
|
421 |
- Default resolution: 49x512x768"""
|
422 |
|
423 |
if training_type == "LoRA Finetune":
|
424 |
-
return base_info + "\n- Required VRAM:
|
425 |
else:
|
426 |
return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
|
427 |
|
@@ -440,51 +430,51 @@ class TrainTab(BaseTab):
|
|
440 |
# Use the first matching preset
|
441 |
preset = matching_presets[0]
|
442 |
return {
|
443 |
-
"
|
444 |
-
"batch_size": preset.get("batch_size",
|
445 |
-
"learning_rate": preset.get("learning_rate",
|
446 |
-
"save_iterations": preset.get("save_iterations",
|
447 |
-
"lora_rank": preset.get("lora_rank",
|
448 |
-
"lora_alpha": preset.get("lora_alpha",
|
449 |
}
|
450 |
|
451 |
# Default fallbacks
|
452 |
if model_type == "hunyuan_video":
|
453 |
return {
|
454 |
-
"
|
455 |
-
"batch_size":
|
456 |
"learning_rate": 2e-5,
|
457 |
-
"save_iterations":
|
458 |
-
"lora_rank":
|
459 |
-
"lora_alpha":
|
460 |
}
|
461 |
elif model_type == "ltx_video":
|
462 |
return {
|
463 |
-
"
|
464 |
-
"batch_size":
|
465 |
-
"learning_rate":
|
466 |
-
"save_iterations":
|
467 |
-
"lora_rank":
|
468 |
-
"lora_alpha":
|
469 |
}
|
470 |
elif model_type == "wan":
|
471 |
return {
|
472 |
-
"
|
473 |
-
"batch_size":
|
474 |
"learning_rate": 5e-5,
|
475 |
-
"save_iterations":
|
476 |
"lora_rank": "32",
|
477 |
"lora_alpha": "32"
|
478 |
}
|
479 |
else:
|
480 |
# Generic defaults
|
481 |
return {
|
482 |
-
"
|
483 |
-
"batch_size":
|
484 |
-
"learning_rate":
|
485 |
-
"save_iterations":
|
486 |
-
"lora_rank":
|
487 |
-
"lora_alpha":
|
488 |
}
|
489 |
|
490 |
def update_training_params(self, preset_name: str) -> Tuple:
|
@@ -522,12 +512,12 @@ class TrainTab(BaseTab):
|
|
522 |
show_lora_params = preset["training_type"] == "lora"
|
523 |
|
524 |
# Use preset defaults but preserve user-modified values if they exist
|
525 |
-
lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank",
|
526 |
-
lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha",
|
527 |
-
|
528 |
-
batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size",
|
529 |
-
learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate",
|
530 |
-
save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations",
|
531 |
|
532 |
# Return values in the same order as the output components
|
533 |
return (
|
@@ -535,7 +525,7 @@ class TrainTab(BaseTab):
|
|
535 |
training_display_name,
|
536 |
lora_rank_val,
|
537 |
lora_alpha_val,
|
538 |
-
|
539 |
batch_size_val,
|
540 |
learning_rate_val,
|
541 |
save_iterations_val,
|
@@ -543,66 +533,6 @@ class TrainTab(BaseTab):
|
|
543 |
gr.Row(visible=show_lora_params)
|
544 |
)
|
545 |
|
546 |
-
def update_training_ui(self, training_state: Dict[str, Any]):
|
547 |
-
"""Update UI components based on training state"""
|
548 |
-
updates = {}
|
549 |
-
|
550 |
-
# Update status box with high-level information
|
551 |
-
status_text = []
|
552 |
-
if training_state["status"] != "idle":
|
553 |
-
status_text.extend([
|
554 |
-
f"Status: {training_state['status']}",
|
555 |
-
f"Progress: {training_state['progress']}",
|
556 |
-
f"Step: {training_state['current_step']}/{training_state['total_steps']}",
|
557 |
-
f"Time elapsed: {training_state['elapsed']}",
|
558 |
-
f"Estimated remaining: {training_state['remaining']}",
|
559 |
-
"",
|
560 |
-
f"Current loss: {training_state['step_loss']}",
|
561 |
-
f"Learning rate: {training_state['learning_rate']}",
|
562 |
-
f"Gradient norm: {training_state['grad_norm']}",
|
563 |
-
f"Memory usage: {training_state['memory']}"
|
564 |
-
])
|
565 |
-
|
566 |
-
if training_state["error_message"]:
|
567 |
-
status_text.append(f"\nError: {training_state['error_message']}")
|
568 |
-
|
569 |
-
updates["status_box"] = "\n".join(status_text)
|
570 |
-
|
571 |
-
# Add current task information to the dedicated box
|
572 |
-
if training_state.get("current_task"):
|
573 |
-
updates["current_task_box"] = training_state["current_task"]
|
574 |
-
else:
|
575 |
-
updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."
|
576 |
-
|
577 |
-
# Update button states
|
578 |
-
updates["start_btn"] = gr.Button(
|
579 |
-
"Start training",
|
580 |
-
interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
|
581 |
-
variant="primary" if training_state["status"] == "idle" else "secondary"
|
582 |
-
)
|
583 |
-
|
584 |
-
updates["stop_btn"] = gr.Button(
|
585 |
-
"Stop training",
|
586 |
-
interactive=(training_state["status"] in ["training", "initializing"]),
|
587 |
-
variant="stop"
|
588 |
-
)
|
589 |
-
|
590 |
-
return updates
|
591 |
-
|
592 |
-
def handle_pause_resume(self):
|
593 |
-
status, _, _ = self.get_latest_status_message_and_logs()
|
594 |
-
|
595 |
-
if status == "paused":
|
596 |
-
self.app.trainer.resume_training()
|
597 |
-
else:
|
598 |
-
self.app.trainer.pause_training()
|
599 |
-
|
600 |
-
return self.get_latest_status_message_logs_and_button_labels()
|
601 |
-
|
602 |
-
def handle_stop(self):
|
603 |
-
self.app.trainer.stop_training()
|
604 |
-
return self.get_latest_status_message_logs_and_button_labels()
|
605 |
-
|
606 |
def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
|
607 |
"""Get latest status message, log content, and status code in a safer way"""
|
608 |
state = self.app.trainer.get_status()
|
@@ -663,61 +593,107 @@ class TrainTab(BaseTab):
|
|
663 |
|
664 |
return (state["status"], state["message"], logs)
|
665 |
|
666 |
-
def
|
667 |
-
"""Get
|
668 |
status, message, logs = self.get_latest_status_message_and_logs()
|
669 |
|
670 |
-
# Add checkpoints detection
|
671 |
-
has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
|
672 |
-
|
673 |
-
button_updates = self.update_training_buttons(status, has_checkpoints).values()
|
674 |
-
|
675 |
# Get current task if available
|
676 |
current_task = ""
|
677 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
678 |
current_task = self.app.log_parser.get_current_task_display()
|
679 |
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
|
|
688 |
is_training = status in ["training", "initializing"]
|
689 |
is_completed = status in ["completed", "error", "stopped"]
|
690 |
|
691 |
start_text = "Continue Training" if has_checkpoints else "Start Training"
|
692 |
|
693 |
-
#
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
),
|
700 |
-
"stop_btn": gr.Button(
|
701 |
-
value="Stop at Last Checkpoint",
|
702 |
-
interactive=is_training,
|
703 |
-
variant="primary" if is_training else "secondary",
|
704 |
-
)
|
705 |
-
}
|
706 |
|
707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
if "delete_checkpoints_btn" in self.components:
|
709 |
-
|
710 |
-
|
711 |
interactive=has_checkpoints and not is_training,
|
712 |
-
variant="stop"
|
713 |
)
|
714 |
else:
|
715 |
-
|
716 |
-
|
717 |
-
value="Resume Training" if status == "paused" else "Pause Training",
|
718 |
interactive=(is_training or status == "paused") and not is_completed,
|
719 |
variant="secondary",
|
720 |
visible=False
|
721 |
)
|
722 |
|
723 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from pathlib import Path
|
10 |
|
11 |
from .base_tab import BaseTab
|
12 |
+
from ..config import (
|
13 |
+
TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
|
14 |
+
DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
15 |
+
DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
|
16 |
+
DEFAULT_LEARNING_RATE,
|
17 |
+
DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
|
18 |
+
DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
|
19 |
+
)
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
|
|
70 |
self.components["lora_rank"] = gr.Dropdown(
|
71 |
label="LoRA Rank",
|
72 |
choices=["16", "32", "64", "128", "256", "512", "1024"],
|
73 |
+
value=DEFAULT_LORA_RANK_STR,
|
74 |
type="value"
|
75 |
)
|
76 |
self.components["lora_alpha"] = gr.Dropdown(
|
77 |
label="LoRA Alpha",
|
78 |
choices=["16", "32", "64", "128", "256", "512", "1024"],
|
79 |
+
value=DEFAULT_LORA_ALPHA_STR,
|
80 |
type="value"
|
81 |
)
|
82 |
|
83 |
with gr.Row():
|
84 |
+
self.components["train_steps"] = gr.Number(
|
85 |
+
label="Number of Training Steps",
|
86 |
+
value=DEFAULT_NB_TRAINING_STEPS,
|
87 |
minimum=1,
|
88 |
precision=0
|
89 |
)
|
|
|
96 |
with gr.Row():
|
97 |
self.components["learning_rate"] = gr.Number(
|
98 |
label="Learning Rate",
|
99 |
+
value=DEFAULT_LEARNING_RATE,
|
100 |
+
minimum=1e-8
|
101 |
)
|
102 |
self.components["save_iterations"] = gr.Number(
|
103 |
label="Save checkpoint every N iterations",
|
104 |
+
value=DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
105 |
+
minimum=1,
|
106 |
precision=0,
|
107 |
info="Model will be saved periodically after these many steps"
|
108 |
)
|
|
|
177 |
|
178 |
return {
|
179 |
self.components["model_info"]: info,
|
180 |
+
self.components["train_steps"]: params["train_steps"],
|
181 |
self.components["batch_size"]: params["batch_size"],
|
182 |
self.components["learning_rate"]: params["learning_rate"],
|
183 |
self.components["save_iterations"]: params["save_iterations"],
|
|
|
193 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
194 |
outputs=[
|
195 |
self.components["model_info"],
|
196 |
+
self.components["train_steps"],
|
197 |
self.components["batch_size"],
|
198 |
self.components["learning_rate"],
|
199 |
self.components["save_iterations"],
|
|
|
211 |
inputs=[self.components["model_type"], self.components["training_type"]],
|
212 |
outputs=[
|
213 |
self.components["model_info"],
|
214 |
+
self.components["train_steps"],
|
215 |
self.components["batch_size"],
|
216 |
self.components["learning_rate"],
|
217 |
self.components["save_iterations"],
|
|
|
232 |
outputs=[]
|
233 |
)
|
234 |
|
235 |
+
self.components["train_steps"].change(
|
236 |
+
fn=lambda v: self.app.update_ui_state(train_steps=v),
|
237 |
+
inputs=[self.components["train_steps"]],
|
238 |
outputs=[]
|
239 |
)
|
240 |
|
|
|
269 |
self.components["training_type"],
|
270 |
self.components["lora_rank"],
|
271 |
self.components["lora_alpha"],
|
272 |
+
self.components["train_steps"],
|
273 |
self.components["batch_size"],
|
274 |
self.components["learning_rate"],
|
275 |
self.components["save_iterations"],
|
|
|
287 |
self.components["training_type"],
|
288 |
self.components["lora_rank"],
|
289 |
self.components["lora_alpha"],
|
290 |
+
self.components["train_steps"],
|
291 |
self.components["batch_size"],
|
292 |
self.components["learning_rate"],
|
293 |
self.components["save_iterations"],
|
|
|
297 |
self.components["status_box"],
|
298 |
self.components["log_box"]
|
299 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
)
|
301 |
|
302 |
+
# Use simplified event handlers for pause/resume and stop
|
303 |
+
third_btn = self.components["delete_checkpoints_btn"] if "delete_checkpoints_btn" in self.components else self.components["pause_resume_btn"]
|
304 |
+
|
305 |
self.components["pause_resume_btn"].click(
|
306 |
fn=self.handle_pause_resume,
|
307 |
outputs=[
|
308 |
self.components["status_box"],
|
309 |
self.components["log_box"],
|
310 |
+
self.components["current_task_box"],
|
311 |
self.components["start_btn"],
|
312 |
self.components["stop_btn"],
|
313 |
+
third_btn
|
|
|
314 |
]
|
315 |
)
|
316 |
|
|
|
319 |
outputs=[
|
320 |
self.components["status_box"],
|
321 |
self.components["log_box"],
|
322 |
+
self.components["current_task_box"],
|
323 |
self.components["start_btn"],
|
324 |
self.components["stop_btn"],
|
325 |
+
third_btn
|
|
|
326 |
]
|
327 |
)
|
328 |
|
|
|
330 |
self.components["delete_checkpoints_btn"].click(
|
331 |
fn=lambda: self.app.trainer.delete_all_checkpoints(),
|
332 |
outputs=[self.components["status_box"]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
)
|
334 |
|
335 |
def handle_training_start(self, preset, model_type, training_type, *args):
|
|
|
381 |
|
382 |
def get_model_info(self, model_type: str, training_type: str) -> str:
|
383 |
"""Get information about the selected model type and training method"""
|
384 |
+
if model_type == "HunyuanVideo":
|
385 |
base_info = """### HunyuanVideo
|
386 |
- Required VRAM: ~48GB minimum
|
387 |
- Recommended batch size: 1-2
|
|
|
393 |
else:
|
394 |
return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
|
395 |
|
396 |
+
elif model_type == "LTX-Video":
|
397 |
base_info = """### LTX-Video
|
398 |
- Recommended batch size: 1-4
|
399 |
- Typical training time: 1-3 hours
|
|
|
404 |
else:
|
405 |
return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
|
406 |
|
407 |
+
elif model_type == "Wan-2.1-T2V":
|
408 |
base_info = """### Wan-2.1-T2V
|
409 |
+
- Recommended batch size: ?
|
410 |
+
- Typical training time: ? hours
|
411 |
- Default resolution: 49x512x768"""
|
412 |
|
413 |
if training_type == "LoRA Finetune":
|
414 |
+
return base_info + "\n- Required VRAM: ?GB minimum\n- Default LoRA rank: 32 (~120 MB)"
|
415 |
else:
|
416 |
return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
|
417 |
|
|
|
430 |
# Use the first matching preset
|
431 |
preset = matching_presets[0]
|
432 |
return {
|
433 |
+
"train_steps": preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
|
434 |
+
"batch_size": preset.get("batch_size", DEFAULT_BATCH_SIZE),
|
435 |
+
"learning_rate": preset.get("learning_rate", DEFAULT_LEARNING_RATE),
|
436 |
+
"save_iterations": preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
|
437 |
+
"lora_rank": preset.get("lora_rank", DEFAULT_LORA_RANK_STR),
|
438 |
+
"lora_alpha": preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
|
439 |
}
|
440 |
|
441 |
# Default fallbacks
|
442 |
if model_type == "hunyuan_video":
|
443 |
return {
|
444 |
+
"train_steps": DEFAULT_NB_TRAINING_STEPS,
|
445 |
+
"batch_size": DEFAULT_BATCH_SIZE,
|
446 |
"learning_rate": 2e-5,
|
447 |
+
"save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
448 |
+
"lora_rank": DEFAULT_LORA_RANK_STR,
|
449 |
+
"lora_alpha": DEFAULT_LORA_ALPHA_STR
|
450 |
}
|
451 |
elif model_type == "ltx_video":
|
452 |
return {
|
453 |
+
"train_steps": DEFAULT_NB_TRAINING_STEPS,
|
454 |
+
"batch_size": DEFAULT_BATCH_SIZE,
|
455 |
+
"learning_rate": DEFAULT_LEARNING_RATE,
|
456 |
+
"save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
457 |
+
"lora_rank": DEFAULT_LORA_RANK_STR,
|
458 |
+
"lora_alpha": DEFAULT_LORA_ALPHA_STR
|
459 |
}
|
460 |
elif model_type == "wan":
|
461 |
return {
|
462 |
+
"train_steps": DEFAULT_NB_TRAINING_STEPS,
|
463 |
+
"batch_size": DEFAULT_BATCH_SIZE,
|
464 |
"learning_rate": 5e-5,
|
465 |
+
"save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
466 |
"lora_rank": "32",
|
467 |
"lora_alpha": "32"
|
468 |
}
|
469 |
else:
|
470 |
# Generic defaults
|
471 |
return {
|
472 |
+
"train_steps": DEFAULT_NB_TRAINING_STEPS,
|
473 |
+
"batch_size": DEFAULT_BATCH_SIZE,
|
474 |
+
"learning_rate": DEFAULT_LEARNING_RATE,
|
475 |
+
"save_iterations": DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
476 |
+
"lora_rank": DEFAULT_LORA_RANK_STR,
|
477 |
+
"lora_alpha": DEFAULT_LORA_ALPHA_STR
|
478 |
}
|
479 |
|
480 |
    def update_training_params(self, preset_name: str) -> Tuple:
        ...

        show_lora_params = preset["training_type"] == "lora"

        # Use preset defaults but preserve user-modified values if they exist
        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", DEFAULT_LORA_RANK_STR) else preset.get("lora_rank", DEFAULT_LORA_RANK_STR)
        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR) else preset.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
        train_steps_val = current_state.get("train_steps") if current_state.get("train_steps") != preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS) else preset.get("train_steps", DEFAULT_NB_TRAINING_STEPS)
        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", DEFAULT_BATCH_SIZE) else preset.get("batch_size", DEFAULT_BATCH_SIZE)
        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", DEFAULT_LEARNING_RATE) else preset.get("learning_rate", DEFAULT_LEARNING_RATE)
        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS) else preset.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)

        # Return values in the same order as the output components
        return (
            ...
            training_display_name,
            lora_rank_val,
            lora_alpha_val,
            train_steps_val,
            batch_size_val,
            learning_rate_val,
            save_iterations_val,
            ...
            gr.Row(visible=show_lora_params)
        )

    def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
        """Get latest status message, log content, and status code in a safer way"""
        state = self.app.trainer.get_status()
        ...
        return (state["status"], state["message"], logs)

    def get_status_updates(self):
        """Get status updates for text components (no variant property)"""
        status, message, logs = self.get_latest_status_message_and_logs()

        # Get current task if available
        current_task = ""
        if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
            current_task = self.app.log_parser.get_current_task_display()

        return message, logs, current_task

    def get_button_updates(self):
        """Get button updates (with variant property)"""
        status, _, _ = self.get_latest_status_message_and_logs()

        # Add checkpoints detection
        has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0

        is_training = status in ["training", "initializing"]
        is_completed = status in ["completed", "error", "stopped"]

        start_text = "Continue Training" if has_checkpoints else "Start Training"

        # Create button updates
        start_btn = gr.Button(
            value=start_text,
            interactive=not is_training,
            variant="primary" if not is_training else "secondary"
        )

        stop_btn = gr.Button(
            value="Stop at Last Checkpoint",
            interactive=is_training,
            variant="primary" if is_training else "secondary"
        )

        # Add delete_checkpoints_btn or pause_resume_btn
        if "delete_checkpoints_btn" in self.components:
            third_btn = gr.Button(
                "Delete All Checkpoints",
                interactive=has_checkpoints and not is_training,
                variant="stop"
            )
        else:
            third_btn = gr.Button(
                "Resume Training" if status == "paused" else "Pause Training",
                interactive=(is_training or status == "paused") and not is_completed,
                variant="secondary",
                visible=False
            )

        return start_btn, stop_btn, third_btn

    def update_training_ui(self, training_state: Dict[str, Any]):
        """Update UI components based on training state"""
        updates = {}

        # Update status box with high-level information
        status_text = []
        if training_state["status"] != "idle":
            status_text.extend([
                f"Status: {training_state['status']}",
                f"Progress: {training_state['progress']}",
                f"Step: {training_state['current_step']}/{training_state['total_steps']}",
                f"Time elapsed: {training_state['elapsed']}",
                f"Estimated remaining: {training_state['remaining']}",
                "",
                f"Current loss: {training_state['step_loss']}",
                f"Learning rate: {training_state['learning_rate']}",
                f"Gradient norm: {training_state['grad_norm']}",
                f"Memory usage: {training_state['memory']}"
            ])

            if training_state["error_message"]:
                status_text.append(f"\nError: {training_state['error_message']}")

        updates["status_box"] = "\n".join(status_text)

        # Add current task information to the dedicated box
        if training_state.get("current_task"):
            updates["current_task_box"] = training_state["current_task"]
        else:
            updates["current_task_box"] = "No active task" if training_state["status"] != "training" else "Waiting for task information..."

        return updates

    def handle_pause_resume(self):
        """Handle pause/resume button click"""
        status, _, _ = self.get_latest_status_message_and_logs()

        if status == "paused":
            self.app.trainer.resume_training()
        else:
            self.app.trainer.pause_training()

        # Return the updates separately for text and buttons
        return (*self.get_status_updates(), *self.get_button_updates())

    def handle_stop(self):
        """Handle stop button click"""
        self.app.trainer.stop_training()

        # Return the updates separately for text and buttons
        return (*self.get_status_updates(), *self.get_button_updates())
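Because handle_pause_resume() and handle_stop() return the three text updates followed by the three button updates, any click wiring that uses them has to list its outputs in that same order. A minimal sketch, assuming the third button is delete_checkpoints_btn; the actual event wiring lives elsewhere in train_tab.py and is not part of this diff:

    # Sketch only (not code from this commit): output order mirrors the return order above.
    self.components["stop_btn"].click(
        fn=self.handle_stop,
        outputs=[
            self.components["status_box"],          # message from get_status_updates()
            self.components["log_box"],             # logs
            self.components["current_task_box"],    # current task
            self.components["start_btn"],           # buttons from get_button_updates()
            self.components["stop_btn"],
            self.components["delete_checkpoints_btn"],
        ],
    )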
vms/ui/video_trainer_ui.py
CHANGED
@@ -9,7 +9,12 @@ from ..services import TrainingService, CaptioningService, SplittingService, Imp
 from ..config import (
     STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
     TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
-    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES
+    MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
+    DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
+    DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
+    DEFAULT_LEARNING_RATE,
+    DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA,
+    DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR
 )
 from ..utils import count_media_files, format_media_title, TrainingLogParser
 from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
@@ -92,7 +97,7 @@ class VideoTrainerUI:
             self.tabs["train_tab"].components["training_type"],
             self.tabs["train_tab"].components["lora_rank"],
             self.tabs["train_tab"].components["lora_alpha"],
-            self.tabs["train_tab"].components["
+            self.tabs["train_tab"].components["train_steps"],
             self.tabs["train_tab"].components["batch_size"],
             self.tabs["train_tab"].components["learning_rate"],
             self.tabs["train_tab"].components["save_iterations"],
@@ -104,31 +109,33 @@
 
     def _add_timers(self):
         """Add auto-refresh timers to the UI"""
-        # Status update timer (every 1 second)
+        # Status update timer for text components (every 1 second)
         status_timer = gr.Timer(value=1)
+        status_timer.tick(
+            fn=self.tabs["train_tab"].get_status_updates,  # Use a new function that returns appropriate updates
+            outputs=[
+                self.tabs["train_tab"].components["status_box"],
+                self.tabs["train_tab"].components["log_box"],
+                self.tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.tabs["train_tab"].components else None
+            ]
+        )
 
-        #
-
-
-            self.tabs["train_tab"].components["log_box"],
+        # Button update timer for button components (every 1 second)
+        button_timer = gr.Timer(value=1)
+        button_outputs = [
             self.tabs["train_tab"].components["start_btn"],
             self.tabs["train_tab"].components["stop_btn"]
         ]
 
-        # Add
-        if "current_task_box" in self.tabs["train_tab"].components:
-            outputs.append(self.tabs["train_tab"].components["current_task_box"])
-
-        # Add delete_checkpoints_btn only if it exists
+        # Add delete_checkpoints_btn or pause_resume_btn as the third button
         if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
-
-
-
-            outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
+            button_outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
+        elif "pause_resume_btn" in self.tabs["train_tab"].components:
+            button_outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
 
-
-            fn=self.tabs["train_tab"].
-            outputs=
+        button_timer.tick(
+            fn=self.tabs["train_tab"].get_button_updates,  # Use a new function for button-specific updates
+            outputs=button_outputs
        )
 
         # Dataset refresh timer (every 5 seconds)
@@ -175,6 +182,11 @@
         if "model_type" in recovery_ui:
             model_type_value = recovery_ui["model_type"]
 
+            # Remove " (LoRA)" suffix if present
+            if " (LoRA)" in model_type_value:
+                model_type_value = model_type_value.replace(" (LoRA)", "")
+                logger.info(f"Removed (LoRA) suffix from model type: {model_type_value}")
+
             # If it's an internal name, convert to display name
             if model_type_value not in MODEL_TYPES:
                 # Find the display name for this internal model type
@@ -201,7 +213,7 @@
             ui_state["training_type"] = training_type_value
 
         # Copy other parameters
-        for param in ["lora_rank", "lora_alpha", "
+        for param in ["lora_rank", "lora_alpha", "train_steps",
                       "batch_size", "learning_rate", "save_iterations", "training_preset"]:
             if param in recovery_ui:
                 ui_state[param] = recovery_ui[param]
@@ -216,31 +228,55 @@
         # Load values (potentially with recovery updates applied)
         ui_state = self.load_ui_values()
 
-        # Ensure model_type is a display name
+        # Ensure model_type is a valid display name
         model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+        # Remove " (LoRA)" suffix if present
+        if " (LoRA)" in model_type_val:
+            model_type_val = model_type_val.replace(" (LoRA)", "")
+            logger.info(f"Removed (LoRA) suffix from model type: {model_type_val}")
+
+        # Ensure it's a valid model type in the dropdown
         if model_type_val not in MODEL_TYPES:
-            # Convert from internal to display name
+            # Convert from internal to display name or use default
+            model_type_found = False
             for display_name, internal_name in MODEL_TYPES.items():
                 if internal_name == model_type_val:
                     model_type_val = display_name
+                    model_type_found = True
                    break
+            # If still not found, use the first model type
+            if not model_type_found:
+                model_type_val = list(MODEL_TYPES.keys())[0]
+                logger.warning(f"Invalid model type '{model_type_val}', using default: {model_type_val}")
 
-        # Ensure training_type is a display name
+        # Ensure training_type is a valid display name
         training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
         if training_type_val not in TRAINING_TYPES:
-            # Convert from internal to display name
+            # Convert from internal to display name or use default
+            training_type_found = False
             for display_name, internal_name in TRAINING_TYPES.items():
                 if internal_name == training_type_val:
                     training_type_val = display_name
+                    training_type_found = True
                    break
+            # If still not found, use the first training type
+            if not training_type_found:
+                training_type_val = list(TRAINING_TYPES.keys())[0]
+                logger.warning(f"Invalid training type '{training_type_val}', using default: {training_type_val}")
 
+        # Validate training preset
         training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
-
-
-
-
-
-
+        if training_preset not in TRAINING_PRESETS:
+            training_preset = list(TRAINING_PRESETS.keys())[0]
+            logger.warning(f"Invalid training preset '{training_preset}', using default: {training_preset}")
+
+        # Rest of the function remains unchanged
+        lora_rank_val = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
+        lora_alpha_val = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+        train_steps_val = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+        batch_size_val = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
+        learning_rate_val = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
+        save_iterations_val = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))
 
         # Initial current task value
         current_task_val = ""
@@ -259,7 +295,7 @@
             training_type_val,
             lora_rank_val,
             lora_alpha_val,
-
+            train_steps_val,
             batch_size_val,
             learning_rate_val,
             save_iterations_val,
@@ -275,12 +311,12 @@
             ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
             ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
             ui_state.get("training_type", list(TRAINING_TYPES.keys())[0]),
-            ui_state.get("lora_rank",
-            ui_state.get("lora_alpha",
-            ui_state.get("
-            ui_state.get("batch_size",
-            ui_state.get("learning_rate",
-            ui_state.get("save_iterations",
+            ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR),
+            ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR),
+            ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS),
+            ui_state.get("batch_size", DEFAULT_BATCH_SIZE),
+            ui_state.get("learning_rate", DEFAULT_LEARNING_RATE),
+            ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
         )
 
     def update_ui_state(self, **kwargs):
@@ -296,12 +332,12 @@
         ui_state = self.trainer.load_ui_state()
 
         # Ensure proper type conversion for numeric values
-        ui_state["lora_rank"] = ui_state.get("lora_rank",
-        ui_state["lora_alpha"] = ui_state.get("lora_alpha",
-        ui_state["
-        ui_state["batch_size"] = int(ui_state.get("batch_size",
-        ui_state["learning_rate"] = float(ui_state.get("learning_rate",
-        ui_state["save_iterations"] = int(ui_state.get("save_iterations",
+        ui_state["lora_rank"] = ui_state.get("lora_rank", DEFAULT_LORA_RANK_STR)
+        ui_state["lora_alpha"] = ui_state.get("lora_alpha", DEFAULT_LORA_ALPHA_STR)
+        ui_state["train_steps"] = int(ui_state.get("train_steps", DEFAULT_NB_TRAINING_STEPS))
+        ui_state["batch_size"] = int(ui_state.get("batch_size", DEFAULT_BATCH_SIZE))
+        ui_state["learning_rate"] = float(ui_state.get("learning_rate", DEFAULT_LEARNING_RATE))
+        ui_state["save_iterations"] = int(ui_state.get("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS))
 
         return ui_state
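The internal-name to display-name conversion above is spelled out twice, once for MODEL_TYPES and once for TRAINING_TYPES. For illustration only (this helper is not part of the commit), the same fallback logic as a small standalone function:

def to_display_name(value: str, mapping: dict) -> str:
    """Map an internal name back to its display name; pass display names through,
    and fall back to the first key when the value is unknown."""
    if value in mapping:
        return value
    for display_name, internal_name in mapping.items():
        if internal_name == value:
            return display_name
    return list(mapping.keys())[0]

# Example: to_display_name("ltx_video", MODEL_TYPES) -> "LTX-Video"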