jbilcke-hf (HF Staff) committed
Commit c90af3c · 1 Parent(s): a529bb7

working on UI improvements

Files changed (3):
  1. app.py +43 -2
  2. config.py +98 -24
  3. training_service.py +11 -4
app.py CHANGED
@@ -661,6 +661,26 @@ class VideoTrainerUI:
             training_dataset
         )
 
+    def update_training_params(self, preset_name: str) -> Dict:
+        """Update UI components based on selected preset"""
+        preset = TRAINING_PRESETS[preset_name]
+
+        # Get preset description for display
+        description = preset.get("description", "")
+        bucket_info = f"\nBucket configuration: {len(preset['training_buckets'])} buckets"
+        info_text = f"{description}{bucket_info}"
+
+        return {
+            "model_type": gr.Dropdown(value=MODEL_TYPES[preset["model_type"]]),
+            "lora_rank": gr.Dropdown(value=preset["lora_rank"]),
+            "lora_alpha": gr.Dropdown(value=preset["lora_alpha"]),
+            "num_epochs": gr.Number(value=preset["num_epochs"]),
+            "batch_size": gr.Number(value=preset["batch_size"]),
+            "learning_rate": gr.Number(value=preset["learning_rate"]),
+            "save_iterations": gr.Number(value=preset["save_iterations"]),
+            "preset_info": gr.Markdown(value=info_text)
+        }
+
     def create_ui(self):
         """Create Gradio interface"""
 
@@ -820,6 +840,15 @@ class VideoTrainerUI:
         with gr.Row():
             train_title = gr.Markdown("## 0 files available for training (0 bytes)")
 
+        with gr.Row():
+            with gr.Column():
+                training_preset = gr.Dropdown(
+                    choices=list(TRAINING_PRESETS.keys()),
+                    label="Training Preset",
+                    value=list(TRAINING_PRESETS.keys())[0]
+                )
+                preset_info = gr.Markdown()
+
         with gr.Row():
             with gr.Column():
                 model_type = gr.Dropdown(
@@ -1096,16 +1125,28 @@ class VideoTrainerUI:
             outputs=[training_dataset]
        )
 
+        training_preset.change(
+            fn=self.update_training_params,
+            inputs=[training_preset],
+            outputs=[
+                model_type, lora_rank, lora_alpha,
+                num_epochs, batch_size, learning_rate,
+                save_iterations, preset_info
+            ]
+        )
+
         # Training control events
         start_btn.click(
-            fn=lambda model_type, *args: (
+            fn=lambda preset, model_type, *args: (
                 self.log_parser.reset(),
                 self.trainer.start_training(
                     MODEL_TYPES[model_type],
-                    *args
+                    *args,
+                    preset_name=preset
                 )
             ),
             inputs=[
+                training_preset,
                 model_type,
                 lora_rank,
                 lora_alpha,
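One subtlety in the wiring above: when a Gradio event handler returns a dict, Gradio expects the keys to be the component objects themselves, or else the handler should return a plain tuple matching the order of `outputs`; a string-keyed dict like the one `update_training_params` returns is not matched against the `outputs` list by name. Below is a minimal, self-contained sketch of the same preset pattern using positional `gr.update(...)` returns (the names and preset values are illustrative, not the app's):

```python
# Minimal sketch, assuming Gradio Blocks: a preset dropdown drives other fields.
import gradio as gr

PRESETS = {  # hypothetical stand-in for TRAINING_PRESETS
    "normal": {"lora_rank": "128", "learning_rate": 3e-5},
    "hq": {"lora_rank": "256", "learning_rate": 3e-5},
}

def apply_preset(name: str):
    p = PRESETS[name]
    # Returned positionally, in the same order as `outputs` below
    return (
        gr.update(value=p["lora_rank"]),
        gr.update(value=p["learning_rate"]),
        gr.update(value=f"Preset: **{name}**"),
    )

with gr.Blocks() as demo:
    preset = gr.Dropdown(choices=list(PRESETS.keys()), value="normal", label="Training Preset")
    lora_rank = gr.Dropdown(choices=["64", "128", "256"], label="LoRA rank")
    learning_rate = gr.Number(label="Learning rate")
    preset_info = gr.Markdown()
    preset.change(fn=apply_preset, inputs=[preset],
                  outputs=[lora_rank, learning_rate, preset_info])

if __name__ == "__main__":
    demo.launch()
```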
config.py CHANGED
@@ -55,8 +55,8 @@ MODEL_TYPES = {
 # it is best to use resolutions that are powers of 8
 # The resolution should be divisible by 32
 # so we cannot use 1080, 540 etc as they are not divisible by 32
-TRAINING_WIDTH = 768 # 32 * 24
-TRAINING_HEIGHT = 512 # 32 * 16
+MEDIUM_19_9_RATIO_WIDTH = 768 # 32 * 24
+MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
 
 # 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
 # 1920 = 32 * 60 (divided by 2: 960 = 32 * 30)
@@ -65,26 +65,100 @@ TRAINING_HEIGHT = 512 # 32 * 16
 # it is important that the resolution buckets properly cover the training dataset,
 # or else that we exclude from the dataset videos that are out of this range
 # right now, finetrainers will crash if that happens, so the workaround is to have more buckets in here
-
-TRAINING_BUCKETS = [
-    (1, TRAINING_HEIGHT, TRAINING_WIDTH), # 1
-    (8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 8 + 1
-    (8 * 2 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 16 + 1
-    (8 * 4 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 32 + 1
-    (8 * 6 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 48 + 1
-    (8 * 8 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 64 + 1
-    (8 * 10 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 80 + 1
-    (8 * 12 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 96 + 1
-    (8 * 14 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 112 + 1
-    (8 * 16 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 128 + 1
-    (8 * 18 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 144 + 1
-    (8 * 20 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 160 + 1
-    (8 * 22 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 176 + 1
-    (8 * 24 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 192 + 1
-    (8 * 28 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 224 + 1
-    (8 * 32 + 1, TRAINING_HEIGHT, TRAINING_WIDTH), # 256 + 1
+
+NB_FRAMES_1 = 1 # 1
+NB_FRAMES_9 = 8 + 1 # 8 + 1
+NB_FRAMES_17 = 8 * 2 + 1 # 16 + 1
+NB_FRAMES_32 = 8 * 4 + 1 # 32 + 1
+NB_FRAMES_48 = 8 * 6 + 1 # 48 + 1
+NB_FRAMES_64 = 8 * 8 + 1 # 64 + 1
+NB_FRAMES_80 = 8 * 10 + 1 # 80 + 1
+NB_FRAMES_96 = 8 * 12 + 1 # 96 + 1
+NB_FRAMES_112 = 8 * 14 + 1 # 112 + 1
+NB_FRAMES_128 = 8 * 16 + 1 # 128 + 1
+NB_FRAMES_144 = 8 * 18 + 1 # 144 + 1
+NB_FRAMES_160 = 8 * 20 + 1 # 160 + 1
+NB_FRAMES_176 = 8 * 22 + 1 # 176 + 1
+NB_FRAMES_192 = 8 * 24 + 1 # 192 + 1
+NB_FRAMES_224 = 8 * 28 + 1 # 224 + 1
+NB_FRAMES_256 = 8 * 32 + 1 # 256 + 1
+# 256 isn't a lot by the way, especially with 60 FPS videos..
+# can we crank it and put more frames in here?
+
+SMALL_TRAINING_BUCKETS = [
+    (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
+    (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
+    (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
+    (NB_FRAMES_32, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_48, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_64, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_80, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_96, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_112, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_128, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_144, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_160, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_176, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_192, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_224, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
+]
+
+MEDIUM_19_9_RATIO_WIDTH = 928 # 32 * 29
+MEDIUM_19_9_RATIO_HEIGHT = 512 # 32 * 16
+
+MEDIUM_19_9_RATIO_BUCKETS = [
+    (NB_FRAMES_1, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 1
+    (NB_FRAMES_9, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 8 + 1
+    (NB_FRAMES_17, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 16 + 1
+    (NB_FRAMES_32, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 32 + 1
+    (NB_FRAMES_48, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 48 + 1
+    (NB_FRAMES_64, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 64 + 1
+    (NB_FRAMES_80, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 80 + 1
+    (NB_FRAMES_96, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 96 + 1
+    (NB_FRAMES_112, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 112 + 1
+    (NB_FRAMES_128, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 128 + 1
+    (NB_FRAMES_144, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 144 + 1
+    (NB_FRAMES_160, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 160 + 1
+    (NB_FRAMES_176, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 176 + 1
+    (NB_FRAMES_192, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 192 + 1
+    (NB_FRAMES_224, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 224 + 1
+    (NB_FRAMES_256, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH), # 256 + 1
 ]
 
+TRAINING_PRESETS = {
+    "HunyuanVideo (normal)": {
+        "model_type": "hunyuan_video",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "num_epochs": 70,
+        "batch_size": 1,
+        "learning_rate": 2e-5,
+        "save_iterations": 500,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+    },
+    "LTX-Video (normal)": {
+        "model_type": "ltx_video",
+        "lora_rank": "128",
+        "lora_alpha": "128",
+        "num_epochs": 70,
+        "batch_size": 1,
+        "learning_rate": 3e-5,
+        "save_iterations": 500,
+        "training_buckets": SMALL_TRAINING_BUCKETS,
+    },
+    "LTX-Video (16:9, HQ)": {
+        "model_type": "ltx_video",
+        "lora_rank": "256",
+        "lora_alpha": "128",
+        "num_epochs": 50,
+        "batch_size": 1,
+        "learning_rate": 3e-5,
+        "save_iterations": 200,
+        "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
+    }
+}
+
 @dataclass
 class TrainingConfig:
     """Configuration class for finetrainers training"""
@@ -159,7 +233,7 @@ class TrainingConfig:
     nccl_timeout: int = 1800
 
     @classmethod
-    def hunyuan_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig':
+    def hunyuan_video_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
         """Configuration for Hunyuan video-to-video LoRA training"""
         return cls(
             model_name="hunyuan_video",
@@ -174,13 +248,13 @@ TRAINING_HEIGHT = 512 # 32 * 16
             gradient_accumulation_steps=1,
             lora_rank=128,
             lora_alpha=128,
-            video_resolution_buckets=TRAINING_BUCKETS,
+            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
             caption_dropout_p=0.05,
             flow_weighting_scheme="none" # Hunyuan specific
         )
 
     @classmethod
-    def ltx_video_lora(cls, data_path: str, output_path: str) -> 'TrainingConfig':
+    def ltx_video_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
         """Configuration for LTX-Video LoRA training"""
         return cls(
             model_name="ltx_video",
@@ -195,7 +269,7 @@ class TrainingConfig:
             gradient_accumulation_steps=4,
             lora_rank=128,
             lora_alpha=128,
-            video_resolution_buckets=TRAINING_BUCKETS,
+            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
             caption_dropout_p=0.05,
             flow_weighting_scheme="logit_normal" # LTX specific
         )
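Two things are worth noting in the bucket refactor. First, every frame count follows the 8 * k + 1 ladder (one seed frame plus a multiple of an 8-frame step) and both dimensions stay divisible by 32, per the comments above. Second, `MEDIUM_19_9_RATIO_WIDTH`/`_HEIGHT` are reassigned from 768x512 to 928x512 between the two lists, so `SMALL_TRAINING_BUCKETS` keeps the earlier 768x512 values because each tuple is evaluated at definition time. A hypothetical helper, not part of the commit, that generates such a ladder instead of enumerating it:

```python
# Hypothetical helper, not in the commit: build a bucket list like
# SMALL_TRAINING_BUCKETS programmatically. Buckets are (num_frames, height, width).
from typing import List, Tuple

def make_buckets(height: int, width: int, ks: List[int]) -> List[Tuple[int, int, int]]:
    """Frame counts follow the 8 * k + 1 ladder; k = 0 yields the single-frame bucket."""
    assert height % 32 == 0 and width % 32 == 0, "resolution must be divisible by 32"
    return [(8 * k + 1, height, width) for k in ks]

# Reproduces the 16 buckets above for the 768x512 case
small_buckets = make_buckets(512, 768, [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 28, 32])
assert small_buckets[1] == (9, 512, 768)
```

Deriving both lists from one call like this would also remove the duplication between `SMALL_TRAINING_BUCKETS` and `MEDIUM_19_9_RATIO_BUCKETS`, which differ only in width.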
training_service.py CHANGED
@@ -257,18 +257,25 @@ class TrainingService:
             logger.error(error_msg)
             return error_msg, "No training data available"
 
-        # Get config for selected model type
+
+        # Get preset configuration
+        preset = TRAINING_PRESETS[preset_name]
+        training_buckets = preset["training_buckets"]
+
+        # Get config for selected model type with preset buckets
         if model_type == "hunyuan_video":
             config = TrainingConfig.hunyuan_video_lora(
                 data_path=str(TRAINING_PATH),
-                output_path=str(OUTPUT_PATH)
+                output_path=str(OUTPUT_PATH),
+                buckets=training_buckets
             )
         else: # ltx_video
             config = TrainingConfig.ltx_video_lora(
                 data_path=str(TRAINING_PATH),
-                output_path=str(OUTPUT_PATH)
+                output_path=str(OUTPUT_PATH),
+                buckets=training_buckets
             )
-
+
         # Update with UI parameters
         config.train_epochs = int(num_epochs)
         config.lora_rank = int(lora_rank)
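End to end, the preset chosen in the UI now travels from the `start_btn.click` lambda (as the `preset_name` keyword, which `start_training` must therefore accept) into the config factories, where `buckets or SMALL_TRAINING_BUCKETS` provides the fallback when no preset buckets are passed. A short sketch of the resulting call path, assuming the import layout implied by the diff; the paths are placeholders:

```python
# Sketch only: exercises the new preset -> buckets -> config path.
from config import TRAINING_PRESETS, TrainingConfig

preset = TRAINING_PRESETS["LTX-Video (16:9, HQ)"]

config = TrainingConfig.ltx_video_lora(
    data_path="/path/to/training",        # placeholder, not the app's TRAINING_PATH
    output_path="/path/to/output",        # placeholder, not the app's OUTPUT_PATH
    buckets=preset["training_buckets"],   # MEDIUM_19_9_RATIO_BUCKETS for this preset
)
assert config.video_resolution_buckets is preset["training_buckets"]

# Omitting buckets falls back to SMALL_TRAINING_BUCKETS via `buckets or ...`
default_config = TrainingConfig.ltx_video_lora(
    data_path="/path/to/training",
    output_path="/path/to/output",
)
```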