jbilcke-hf HF Staff commited on
Commit
ed18efe
·
1 Parent(s): c47044e

add some fixes to the UI

Browse files
finetrainers/args.py CHANGED
@@ -853,8 +853,9 @@ def _validate_dataset_args(args: BaseArgs):
853
 
854
 
855
  def _validate_validation_args(args: BaseArgs):
856
- if args.dp_shards > 1 and args.enable_model_cpu_offload:
857
- raise ValueError("Model CPU offload is not supported with FSDP at the moment.")
 
858
 
859
 
860
  def _display_helper_messages(args: argparse.Namespace):
 
853
 
854
 
855
  def _validate_validation_args(args: BaseArgs):
856
+ if args.enable_model_cpu_offload:
857
+ if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.cp_degree, args.tp_degree]):
858
+ raise ValueError("Model CPU offload is not supported on multi-GPU at the moment.")
859
 
860
 
861
  def _display_helper_messages(args: argparse.Namespace):
finetrainers/patches/__init__.py CHANGED
@@ -17,7 +17,12 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
17
  if parallel_backend.tensor_parallel_enabled:
18
  patch.patch_apply_rotary_emb_for_tp_compatibility()
19
 
 
 
 
 
 
20
  if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
21
- from dependencies.peft import patch
22
 
23
  patch.patch_peft_move_adapter_to_device_of_base_layer()
 
17
  if parallel_backend.tensor_parallel_enabled:
18
  patch.patch_apply_rotary_emb_for_tp_compatibility()
19
 
20
+ if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules:
21
+ from .models.wan import patch
22
+
23
+ patch.patch_time_text_image_embedding_forward()
24
+
25
  if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
26
+ from .dependencies.peft import patch
27
 
28
  patch.patch_peft_move_adapter_to_device_of_base_layer()
finetrainers/patches/models/ltx_video/patch.py CHANGED
@@ -16,7 +16,7 @@ def patch_apply_rotary_emb_for_tp_compatibility() -> None:
16
 
17
 
18
  def _perform_ltx_transformer_forward_patch() -> None:
19
- LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
20
 
21
 
22
  def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
@@ -35,7 +35,7 @@ def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
35
  diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
36
 
37
 
38
- def _patched_LTXVideoTransformer3Dforward(
39
  self,
40
  hidden_states: torch.Tensor,
41
  encoder_hidden_states: torch.Tensor,
 
16
 
17
 
18
  def _perform_ltx_transformer_forward_patch() -> None:
19
+ LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward
20
 
21
 
22
  def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
 
35
  diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
36
 
37
 
38
+ def _patched_LTXVideoTransformer3D_forward(
39
  self,
40
  hidden_states: torch.Tensor,
41
  encoder_hidden_states: torch.Tensor,
finetrainers/patches/models/wan/patch.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import diffusers
4
+ import torch
5
+
6
+
7
+ def patch_time_text_image_embedding_forward() -> None:
8
+ _patch_time_text_image_embedding_forward()
9
+
10
+
11
+ def _patch_time_text_image_embedding_forward() -> None:
12
+ diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
13
+ _patched_WanTimeTextImageEmbedding_forward
14
+ )
15
+
16
+
17
+ def _patched_WanTimeTextImageEmbedding_forward(
18
+ self,
19
+ timestep: torch.Tensor,
20
+ encoder_hidden_states: torch.Tensor,
21
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
22
+ ):
23
+ # Some code has been removed compared to original implementation in Diffusers
24
+ # Also, timestep is typed as that of encoder_hidden_states
25
+ timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
26
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
27
+ timestep_proj = self.time_proj(self.act_fn(temb))
28
+
29
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
30
+ if encoder_hidden_states_image is not None:
31
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
32
+
33
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -334,6 +334,7 @@ class SFTTrainer:
334
  parallel_backend = self.state.parallel_backend
335
  train_state = self.state.train_state
336
  device = parallel_backend.device
 
337
 
338
  memory_statistics = utils.get_memory_statistics()
339
  logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
@@ -447,8 +448,8 @@ class SFTTrainer:
447
 
448
  logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
449
 
450
- utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype)
451
- utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype)
452
  latent_model_conditions = utils.make_contiguous(latent_model_conditions)
453
  condition_model_conditions = utils.make_contiguous(condition_model_conditions)
454
 
@@ -729,6 +730,7 @@ class SFTTrainer:
729
 
730
  parallel_backend.wait_for_everyone()
731
  if not final_validation:
 
732
  self.transformer.train()
733
 
734
  def _evaluate(self) -> None:
 
334
  parallel_backend = self.state.parallel_backend
335
  train_state = self.state.train_state
336
  device = parallel_backend.device
337
+ dtype = self.args.transformer_dtype
338
 
339
  memory_statistics = utils.get_memory_statistics()
340
  logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
 
448
 
449
  logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
450
 
451
+ latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
452
+ condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
453
  latent_model_conditions = utils.make_contiguous(latent_model_conditions)
454
  condition_model_conditions = utils.make_contiguous(condition_model_conditions)
455
 
 
730
 
731
  parallel_backend.wait_for_everyone()
732
  if not final_validation:
733
+ self._move_components_to_device()
734
  self.transformer.train()
735
 
736
  def _evaluate(self) -> None:
vms/ui/app_ui.py CHANGED
@@ -146,8 +146,8 @@ class AppUI:
146
  # Sidebar for navigation
147
  with gr.Sidebar(position="left", open=True):
148
  gr.Markdown("# 🎞️ Video Model Studio")
149
- self.components["current_project_btn"] = gr.Button("Current Project", variant="primary")
150
- self.components["system_monitoring_btn"] = gr.Button("System Monitoring")
151
 
152
  # Main content area with tabs
153
  with gr.Column():
@@ -174,7 +174,7 @@ class AppUI:
174
  tab_obj.create(project_tabs)
175
 
176
  # Monitoring View Tab
177
- with gr.Tab("📊 System Monitoring", id=1) as monitoring_view:
178
  # Create monitoring tabs
179
  with gr.Tabs() as monitoring_tabs:
180
  # Store reference to monitoring tabs component
@@ -257,19 +257,13 @@ class AppUI:
257
  self.project_tabs["train_tab"].components["stop_btn"],
258
  self.project_tabs["train_tab"].components["delete_checkpoints_btn"]
259
  ]
260
-
261
  button_timer.tick(
262
  fn=self.project_tabs["train_tab"].get_button_updates,
263
  outputs=button_outputs
264
  )
265
 
266
-
267
- # Add delete_checkpoints_btn or pause_resume_btn as the third button
268
- if "delete_checkpoints_btn" in self.project_tabs["train_tab"].components:
269
- button_outputs.append(self.project_tabs["train_tab"].components["delete_checkpoints_btn"])
270
- elif "pause_resume_btn" in self.project_tabs["train_tab"].components:
271
- button_outputs.append(self.project_tabs["train_tab"].components["pause_resume_btn"])
272
-
273
  # Dataset refresh timer (every 5 seconds)
274
  dataset_timer = gr.Timer(value=5)
275
  dataset_timer.tick(
@@ -558,21 +552,21 @@ class AppUI:
558
 
559
  if is_training:
560
  # Active training detected
561
- start_btn_props = {"interactive": False, "variant": "secondary", "value": "Start new training"}
562
- resume_btn_props = {"interactive": False, "variant": "secondary", "value": "Start from latest checkpoint"}
563
  stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
564
  delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
565
  else:
566
  # No active training
567
- start_btn_props = {"interactive": True, "variant": "primary", "value": "Start new training"}
568
- resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "Start from latest checkpoint"}
569
  stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
570
  delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
571
  else:
572
  # Use button states from recovery, adding the new resume button
573
- start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start new training"})
574
  resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
575
- "variant": "primary", "value": "Start from latest checkpoint"}
576
  stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
577
  delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
578
 
@@ -604,7 +598,7 @@ class AppUI:
604
 
605
  return (
606
  gr.Markdown(value=caption_title),
607
- gr.Markdown(value=f"{train_title} available for training")
608
  )
609
 
610
  def refresh_dataset(self):
 
146
  # Sidebar for navigation
147
  with gr.Sidebar(position="left", open=True):
148
  gr.Markdown("# 🎞️ Video Model Studio")
149
+ self.components["current_project_btn"] = gr.Button("📂 Current Project", variant="primary")
150
+ self.components["system_monitoring_btn"] = gr.Button("🌡️ System Monitoring")
151
 
152
  # Main content area with tabs
153
  with gr.Column():
 
174
  tab_obj.create(project_tabs)
175
 
176
  # Monitoring View Tab
177
+ with gr.Tab("🌡️ System Monitoring", id=1) as monitoring_view:
178
  # Create monitoring tabs
179
  with gr.Tabs() as monitoring_tabs:
180
  # Store reference to monitoring tabs component
 
257
  self.project_tabs["train_tab"].components["stop_btn"],
258
  self.project_tabs["train_tab"].components["delete_checkpoints_btn"]
259
  ]
260
+
261
  button_timer.tick(
262
  fn=self.project_tabs["train_tab"].get_button_updates,
263
  outputs=button_outputs
264
  )
265
 
266
+
 
 
 
 
 
 
267
  # Dataset refresh timer (every 5 seconds)
268
  dataset_timer = gr.Timer(value=5)
269
  dataset_timer.tick(
 
552
 
553
  if is_training:
554
  # Active training detected
555
+ start_btn_props = {"interactive": False, "variant": "secondary", "value": "🚀 Start new training"}
556
+ resume_btn_props = {"interactive": False, "variant": "secondary", "value": "🛰️ Start from latest checkpoint"}
557
  stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
558
  delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
559
  else:
560
  # No active training
561
+ start_btn_props = {"interactive": True, "variant": "primary", "value": "🚀 Start new training"}
562
+ resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "🛰️ Start from latest checkpoint"}
563
  stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
564
  delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
565
  else:
566
  # Use button states from recovery, adding the new resume button
567
+ start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "🚀 Start new training"})
568
  resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
569
+ "variant": "primary", "value": "🛰️ Start from latest checkpoint"}
570
  stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
571
  delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
572
 
 
598
 
599
  return (
600
  gr.Markdown(value=caption_title),
601
+ gr.Markdown(value=f"{train_title}")
602
  )
603
 
604
  def refresh_dataset(self):
vms/ui/monitoring/tabs/general_tab.py CHANGED
@@ -32,7 +32,7 @@ class GeneralTab(BaseTab):
32
  """Create the Monitor tab UI components"""
33
  with gr.TabItem(self.title, id=self.id) as tab:
34
  with gr.Row():
35
- gr.Markdown("## System Monitoring")
36
 
37
  # Current metrics
38
  with gr.Row():
 
32
  """Create the Monitor tab UI components"""
33
  with gr.TabItem(self.title, id=self.id) as tab:
34
  with gr.Row():
35
+ gr.Markdown("## 🌡️ System Monitoring")
36
 
37
  # Current metrics
38
  with gr.Row():
vms/ui/project/tabs/manage_tab.py CHANGED
@@ -29,23 +29,23 @@ class ManageTab(BaseTab):
29
  with gr.TabItem(self.title, id=self.id) as tab:
30
  with gr.Row():
31
  with gr.Column():
32
- gr.Markdown("## Download your model")
33
  gr.Markdown("There is currently a bug, you might have to click multiple times to trigger a download.")
34
 
35
  with gr.Row():
36
  self.components["download_dataset_btn"] = gr.DownloadButton(
37
- "Download training dataset",
38
  variant="secondary",
39
  size="lg"
40
  )
41
  self.components["download_model_btn"] = gr.DownloadButton(
42
- "Download model weights",
43
  variant="secondary",
44
  size="lg"
45
  )
46
  with gr.Row():
47
  with gr.Column():
48
- gr.Markdown("## Publish your model")
49
  gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
50
 
51
  with gr.Row():
@@ -65,19 +65,19 @@ class ManageTab(BaseTab):
65
 
66
  with gr.Row():
67
  with gr.Column():
68
- gr.Markdown("## Delete your data")
69
  gr.Markdown("Make sure you have made a backup first.")
70
  gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
71
 
72
  with gr.Row():
73
  with gr.Column():
74
- gr.Markdown("### Delete specific data")
75
  gr.Markdown("You can selectively delete either the dataset and/or the last model data.")
76
 
77
  with gr.Row():
78
  with gr.Column(scale=1):
79
  self.components["delete_dataset_btn"] = gr.Button(
80
- "Delete dataset (images, video, captions)",
81
  variant="secondary"
82
  )
83
  self.components["delete_dataset_status"] = gr.Textbox(
@@ -88,7 +88,7 @@ class ManageTab(BaseTab):
88
 
89
  with gr.Column(scale=1):
90
  self.components["delete_model_btn"] = gr.Button(
91
- "Delete model (checkpoints, weights, config)",
92
  variant="secondary"
93
  )
94
  self.components["delete_model_status"] = gr.Textbox(
@@ -99,12 +99,12 @@ class ManageTab(BaseTab):
99
 
100
  with gr.Row():
101
  with gr.Column():
102
- gr.Markdown("### Delete everything")
103
- gr.Markdown("This will delete both the dataset (all images, videos and captions) AND the latest model (weights, checkpoints, settings). So use with care!")
104
 
105
  with gr.Row():
106
  self.components["global_stop_btn"] = gr.Button(
107
- "Stop everything and delete my data",
108
  variant="stop"
109
  )
110
  self.components["global_status"] = gr.Textbox(
 
29
  with gr.TabItem(self.title, id=self.id) as tab:
30
  with gr.Row():
31
  with gr.Column():
32
+ gr.Markdown("## 🏦 Backup your model")
33
  gr.Markdown("There is currently a bug, you might have to click multiple times to trigger a download.")
34
 
35
  with gr.Row():
36
  self.components["download_dataset_btn"] = gr.DownloadButton(
37
+ "📦 Download training dataset (.zip)",
38
  variant="secondary",
39
  size="lg"
40
  )
41
  self.components["download_model_btn"] = gr.DownloadButton(
42
+ "🧠 Download weights (.safetensors)",
43
  variant="secondary",
44
  size="lg"
45
  )
46
  with gr.Row():
47
  with gr.Column():
48
+ gr.Markdown("## 📡 Publish your model")
49
  gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
50
 
51
  with gr.Row():
 
65
 
66
  with gr.Row():
67
  with gr.Column():
68
+ gr.Markdown("## ♻️ Delete your data")
69
  gr.Markdown("Make sure you have made a backup first.")
70
  gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
71
 
72
  with gr.Row():
73
  with gr.Column():
74
+ gr.Markdown("### 🧽 Delete specific data")
75
  gr.Markdown("You can selectively delete either the dataset and/or the last model data.")
76
 
77
  with gr.Row():
78
  with gr.Column(scale=1):
79
  self.components["delete_dataset_btn"] = gr.Button(
80
+ "🚨 Delete dataset (images, video, captions)",
81
  variant="secondary"
82
  )
83
  self.components["delete_dataset_status"] = gr.Textbox(
 
88
 
89
  with gr.Column(scale=1):
90
  self.components["delete_model_btn"] = gr.Button(
91
+ "🚨 Delete model (checkpoints, weights, config)",
92
  variant="secondary"
93
  )
94
  self.components["delete_model_status"] = gr.Textbox(
 
99
 
100
  with gr.Row():
101
  with gr.Column():
102
+ gr.Markdown("### ☢️ Nuke all project data")
103
+ gr.Markdown("This will nuke the original dataset (all images, videos and captions), the training dataset, and the model outputs (weights, checkpoints, settings). So use with care!")
104
 
105
  with gr.Row():
106
  self.components["global_stop_btn"] = gr.Button(
107
+ "🚨 Delete all project data and models (are you sure?!)",
108
  variant="stop"
109
  )
110
  self.components["global_status"] = gr.Textbox(
vms/ui/project/tabs/preview_tab.py CHANGED
@@ -29,7 +29,7 @@ class PreviewTab(BaseTab):
29
  """Create the Preview tab UI components"""
30
  with gr.TabItem(self.title, id=self.id) as tab:
31
  with gr.Row():
32
- gr.Markdown("## Preview your model")
33
 
34
  with gr.Row():
35
  with gr.Column(scale=2):
@@ -202,11 +202,11 @@ class PreviewTab(BaseTab):
202
  interactive=False
203
  )
204
 
205
- with gr.Accordion("Log", open=True):
206
  self.components["log"] = gr.TextArea(
207
  label="Generation Log",
208
  interactive=False,
209
- lines=20
210
  )
211
 
212
  return tab
 
29
  """Create the Preview tab UI components"""
30
  with gr.TabItem(self.title, id=self.id) as tab:
31
  with gr.Row():
32
+ gr.Markdown("## 🔬 Preview your model")
33
 
34
  with gr.Row():
35
  with gr.Column(scale=2):
 
202
  interactive=False
203
  )
204
 
205
+ with gr.Accordion("Log", open=False):
206
  self.components["log"] = gr.TextArea(
207
  label="Generation Log",
208
  interactive=False,
209
+ lines=60
210
  )
211
 
212
  return tab
vms/ui/project/tabs/train_tab.py CHANGED
@@ -44,7 +44,7 @@ class TrainTab(BaseTab):
44
  with gr.Row():
45
  with gr.Column():
46
  with gr.Row():
47
- self.components["train_title"] = gr.Markdown("## 0 files available for training (0 bytes)")
48
 
49
  with gr.Row():
50
  with gr.Column():
@@ -181,79 +181,80 @@ class TrainTab(BaseTab):
181
 
182
  with gr.Row():
183
  with gr.Column():
184
- # Add description of the training buttons
185
- self.components["training_buttons_info"] = gr.Markdown("""
186
- ## Training Options
187
- - **Start new training**: Begins training from scratch (clears previous checkpoints)
188
- - **Start from latest checkpoint**: Continues training from the most recent checkpoint
189
- """)
190
-
191
- with gr.Row():
192
- # Check for existing checkpoints to determine button text
193
- checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
194
- has_checkpoints = len(checkpoints) > 0
195
 
196
- # Rename "Start Training" to "Start new training"
197
- self.components["start_btn"] = gr.Button(
198
- "Start new training",
199
- variant="primary",
200
- interactive=not ASK_USER_TO_DUPLICATE_SPACE
201
- )
202
-
203
- # Add new button for continuing from checkpoint
204
- self.components["resume_btn"] = gr.Button(
205
- "Start from latest checkpoint",
206
- variant="primary",
207
- interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
208
- )
209
-
210
- with gr.Row():
211
- # Just use stop and pause buttons for now to ensure compatibility
212
- self.components["stop_btn"] = gr.Button(
213
- "Stop at Last Checkpoint",
214
- variant="primary",
215
- interactive=False
216
- )
217
-
218
- self.components["pause_resume_btn"] = gr.Button(
219
- "Resume Training",
220
- variant="secondary",
221
- interactive=False,
222
- visible=False
223
- )
224
-
225
- # Add delete checkpoints button
226
- self.components["delete_checkpoints_btn"] = gr.Button(
227
- "Delete All Checkpoints",
228
- variant="stop",
229
- interactive=has_checkpoints
230
- )
231
 
232
- with gr.Column():
233
- with gr.Row():
234
- with gr.Column():
235
- self.components["status_box"] = gr.Textbox(
236
- label="Training Status",
237
- interactive=False,
238
- lines=4
239
- )
240
-
241
- # Add new component for current task progress
242
- self.components["current_task_box"] = gr.Textbox(
243
- label="Current Task Progress",
244
- interactive=False,
245
- lines=3,
246
- elem_id="current_task_display"
247
- )
248
-
249
- with gr.Accordion("Finetrainers output (or see app logs for more details)"):
250
- self.components["log_box"] = gr.TextArea(
251
- #label="",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  interactive=False,
253
- lines=60,
254
- max_lines=600,
255
- autoscroll=True
256
  )
 
 
 
 
 
 
 
 
 
257
 
258
  return tab
259
 
@@ -963,7 +964,7 @@ class TrainTab(BaseTab):
963
 
964
  # Create button updates
965
  start_btn = gr.Button(
966
- value="Start new training",
967
  interactive=not is_training,
968
  variant="primary" if not is_training else "secondary"
969
  )
 
44
  with gr.Row():
45
  with gr.Column():
46
  with gr.Row():
47
+ self.components["train_title"] = gr.Markdown("## 0 files in the training dataset")
48
 
49
  with gr.Row():
50
  with gr.Column():
 
181
 
182
  with gr.Row():
183
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ with gr.Row():
186
+ with gr.Column():
187
+ # Add description of the training buttons
188
+ self.components["training_buttons_info"] = gr.Markdown("""
189
+ ## ⚗️ Train your model on your dataset
190
+ - **Start new training**: Begins training from scratch (clears previous checkpoints)
191
+ - **Start from latest checkpoint**: Continues training from the most recent checkpoint
192
+ """)
193
+
194
+ with gr.Row():
195
+ # Check for existing checkpoints to determine button text
196
+ checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
197
+ has_checkpoints = len(checkpoints) > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
+ self.components["start_btn"] = gr.Button(
200
+ "🚀 Start new training",
201
+ variant="primary",
202
+ interactive=not ASK_USER_TO_DUPLICATE_SPACE
203
+ )
204
+
205
+ # Add new button for continuing from checkpoint
206
+ self.components["resume_btn"] = gr.Button(
207
+ "🛰️ Start from latest checkpoint",
208
+ variant="primary",
209
+ interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
210
+ )
211
+
212
+ with gr.Row():
213
+ # Just use stop and pause buttons for now to ensure compatibility
214
+ self.components["stop_btn"] = gr.Button(
215
+ "Stop at Last Checkpoint",
216
+ variant="primary",
217
+ interactive=False
218
+ )
219
+
220
+ self.components["pause_resume_btn"] = gr.Button(
221
+ "Resume Training",
222
+ variant="secondary",
223
+ interactive=False,
224
+ visible=False
225
+ )
226
+
227
+ # Add delete checkpoints button
228
+ self.components["delete_checkpoints_btn"] = gr.Button(
229
+ "Delete All Checkpoints",
230
+ variant="stop",
231
+ interactive=has_checkpoints
232
+ )
233
+
234
+ with gr.Row():
235
+ with gr.Column():
236
+ self.components["status_box"] = gr.Textbox(
237
+ label="Training Status",
238
+ interactive=False,
239
+ lines=4
240
+ )
241
+
242
+ # Add new component for current task progress
243
+ self.components["current_task_box"] = gr.Textbox(
244
+ label="Current Task Progress",
245
  interactive=False,
246
+ lines=3,
247
+ elem_id="current_task_display"
 
248
  )
249
+
250
+ with gr.Accordion("Finetrainers output (or see app logs for more details)", open=False):
251
+ self.components["log_box"] = gr.TextArea(
252
+ #label="",
253
+ interactive=False,
254
+ lines=60,
255
+ max_lines=600,
256
+ autoscroll=True
257
+ )
258
 
259
  return tab
260
 
 
964
 
965
  # Create button updates
966
  start_btn = gr.Button(
967
+ value="🚀 Start new training",
968
  interactive=not is_training,
969
  variant="primary" if not is_training else "secondary"
970
  )