Spaces:
Running
Running
Commit
·
ed18efe
1
Parent(s):
c47044e
add some fixes to the UI
Browse files- finetrainers/args.py +3 -2
- finetrainers/patches/__init__.py +6 -1
- finetrainers/patches/models/ltx_video/patch.py +2 -2
- finetrainers/patches/models/wan/patch.py +33 -0
- finetrainers/trainer/sft_trainer/trainer.py +4 -2
- vms/ui/app_ui.py +12 -18
- vms/ui/monitoring/tabs/general_tab.py +1 -1
- vms/ui/project/tabs/manage_tab.py +11 -11
- vms/ui/project/tabs/preview_tab.py +3 -3
- vms/ui/project/tabs/train_tab.py +72 -71
finetrainers/args.py
CHANGED
@@ -853,8 +853,9 @@ def _validate_dataset_args(args: BaseArgs):
|
|
853 |
|
854 |
|
855 |
def _validate_validation_args(args: BaseArgs):
|
856 |
-
if args.
|
857 |
-
|
|
|
858 |
|
859 |
|
860 |
def _display_helper_messages(args: argparse.Namespace):
|
|
|
853 |
|
854 |
|
855 |
def _validate_validation_args(args: BaseArgs):
|
856 |
+
if args.enable_model_cpu_offload:
|
857 |
+
if any(x > 1 for x in [args.pp_degree, args.dp_degree, args.dp_shards, args.cp_degree, args.tp_degree]):
|
858 |
+
raise ValueError("Model CPU offload is not supported on multi-GPU at the moment.")
|
859 |
|
860 |
|
861 |
def _display_helper_messages(args: argparse.Namespace):
|
finetrainers/patches/__init__.py
CHANGED
@@ -17,7 +17,12 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
|
|
17 |
if parallel_backend.tensor_parallel_enabled:
|
18 |
patch.patch_apply_rotary_emb_for_tp_compatibility()
|
19 |
|
|
|
|
|
|
|
|
|
|
|
20 |
if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
|
21 |
-
from dependencies.peft import patch
|
22 |
|
23 |
patch.patch_peft_move_adapter_to_device_of_base_layer()
|
|
|
17 |
if parallel_backend.tensor_parallel_enabled:
|
18 |
patch.patch_apply_rotary_emb_for_tp_compatibility()
|
19 |
|
20 |
+
if args.model_name == ModelType.WAN and "transformer" in args.layerwise_upcasting_modules:
|
21 |
+
from .models.wan import patch
|
22 |
+
|
23 |
+
patch.patch_time_text_image_embedding_forward()
|
24 |
+
|
25 |
if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
|
26 |
+
from .dependencies.peft import patch
|
27 |
|
28 |
patch.patch_peft_move_adapter_to_device_of_base_layer()
|
finetrainers/patches/models/ltx_video/patch.py
CHANGED
@@ -16,7 +16,7 @@ def patch_apply_rotary_emb_for_tp_compatibility() -> None:
|
|
16 |
|
17 |
|
18 |
def _perform_ltx_transformer_forward_patch() -> None:
|
19 |
-
LTXVideoTransformer3DModel.forward =
|
20 |
|
21 |
|
22 |
def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
|
@@ -35,7 +35,7 @@ def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
|
|
35 |
diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
|
36 |
|
37 |
|
38 |
-
def
|
39 |
self,
|
40 |
hidden_states: torch.Tensor,
|
41 |
encoder_hidden_states: torch.Tensor,
|
|
|
16 |
|
17 |
|
18 |
def _perform_ltx_transformer_forward_patch() -> None:
|
19 |
+
LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward
|
20 |
|
21 |
|
22 |
def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
|
|
|
35 |
diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
|
36 |
|
37 |
|
38 |
+
def _patched_LTXVideoTransformer3D_forward(
|
39 |
self,
|
40 |
hidden_states: torch.Tensor,
|
41 |
encoder_hidden_states: torch.Tensor,
|
finetrainers/patches/models/wan/patch.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import diffusers
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def patch_time_text_image_embedding_forward() -> None:
|
8 |
+
_patch_time_text_image_embedding_forward()
|
9 |
+
|
10 |
+
|
11 |
+
def _patch_time_text_image_embedding_forward() -> None:
|
12 |
+
diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
|
13 |
+
_patched_WanTimeTextImageEmbedding_forward
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
def _patched_WanTimeTextImageEmbedding_forward(
|
18 |
+
self,
|
19 |
+
timestep: torch.Tensor,
|
20 |
+
encoder_hidden_states: torch.Tensor,
|
21 |
+
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
22 |
+
):
|
23 |
+
# Some code has been removed compared to original implementation in Diffusers
|
24 |
+
# Also, timestep is typed as that of encoder_hidden_states
|
25 |
+
timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
|
26 |
+
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
27 |
+
timestep_proj = self.time_proj(self.act_fn(temb))
|
28 |
+
|
29 |
+
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
30 |
+
if encoder_hidden_states_image is not None:
|
31 |
+
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
|
32 |
+
|
33 |
+
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
|
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -334,6 +334,7 @@ class SFTTrainer:
|
|
334 |
parallel_backend = self.state.parallel_backend
|
335 |
train_state = self.state.train_state
|
336 |
device = parallel_backend.device
|
|
|
337 |
|
338 |
memory_statistics = utils.get_memory_statistics()
|
339 |
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
|
@@ -447,8 +448,8 @@ class SFTTrainer:
|
|
447 |
|
448 |
logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
|
449 |
|
450 |
-
utils.align_device_and_dtype(latent_model_conditions, device,
|
451 |
-
utils.align_device_and_dtype(condition_model_conditions, device,
|
452 |
latent_model_conditions = utils.make_contiguous(latent_model_conditions)
|
453 |
condition_model_conditions = utils.make_contiguous(condition_model_conditions)
|
454 |
|
@@ -729,6 +730,7 @@ class SFTTrainer:
|
|
729 |
|
730 |
parallel_backend.wait_for_everyone()
|
731 |
if not final_validation:
|
|
|
732 |
self.transformer.train()
|
733 |
|
734 |
def _evaluate(self) -> None:
|
|
|
334 |
parallel_backend = self.state.parallel_backend
|
335 |
train_state = self.state.train_state
|
336 |
device = parallel_backend.device
|
337 |
+
dtype = self.args.transformer_dtype
|
338 |
|
339 |
memory_statistics = utils.get_memory_statistics()
|
340 |
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
|
|
|
448 |
|
449 |
logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
|
450 |
|
451 |
+
latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
|
452 |
+
condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
|
453 |
latent_model_conditions = utils.make_contiguous(latent_model_conditions)
|
454 |
condition_model_conditions = utils.make_contiguous(condition_model_conditions)
|
455 |
|
|
|
730 |
|
731 |
parallel_backend.wait_for_everyone()
|
732 |
if not final_validation:
|
733 |
+
self._move_components_to_device()
|
734 |
self.transformer.train()
|
735 |
|
736 |
def _evaluate(self) -> None:
|
vms/ui/app_ui.py
CHANGED
@@ -146,8 +146,8 @@ class AppUI:
|
|
146 |
# Sidebar for navigation
|
147 |
with gr.Sidebar(position="left", open=True):
|
148 |
gr.Markdown("# 🎞️ Video Model Studio")
|
149 |
-
self.components["current_project_btn"] = gr.Button("Current Project", variant="primary")
|
150 |
-
self.components["system_monitoring_btn"] = gr.Button("System Monitoring")
|
151 |
|
152 |
# Main content area with tabs
|
153 |
with gr.Column():
|
@@ -174,7 +174,7 @@ class AppUI:
|
|
174 |
tab_obj.create(project_tabs)
|
175 |
|
176 |
# Monitoring View Tab
|
177 |
-
with gr.Tab("
|
178 |
# Create monitoring tabs
|
179 |
with gr.Tabs() as monitoring_tabs:
|
180 |
# Store reference to monitoring tabs component
|
@@ -257,19 +257,13 @@ class AppUI:
|
|
257 |
self.project_tabs["train_tab"].components["stop_btn"],
|
258 |
self.project_tabs["train_tab"].components["delete_checkpoints_btn"]
|
259 |
]
|
260 |
-
|
261 |
button_timer.tick(
|
262 |
fn=self.project_tabs["train_tab"].get_button_updates,
|
263 |
outputs=button_outputs
|
264 |
)
|
265 |
|
266 |
-
|
267 |
-
# Add delete_checkpoints_btn or pause_resume_btn as the third button
|
268 |
-
if "delete_checkpoints_btn" in self.project_tabs["train_tab"].components:
|
269 |
-
button_outputs.append(self.project_tabs["train_tab"].components["delete_checkpoints_btn"])
|
270 |
-
elif "pause_resume_btn" in self.project_tabs["train_tab"].components:
|
271 |
-
button_outputs.append(self.project_tabs["train_tab"].components["pause_resume_btn"])
|
272 |
-
|
273 |
# Dataset refresh timer (every 5 seconds)
|
274 |
dataset_timer = gr.Timer(value=5)
|
275 |
dataset_timer.tick(
|
@@ -558,21 +552,21 @@ class AppUI:
|
|
558 |
|
559 |
if is_training:
|
560 |
# Active training detected
|
561 |
-
start_btn_props = {"interactive": False, "variant": "secondary", "value": "Start new training"}
|
562 |
-
resume_btn_props = {"interactive": False, "variant": "secondary", "value": "Start from latest checkpoint"}
|
563 |
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
564 |
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
565 |
else:
|
566 |
# No active training
|
567 |
-
start_btn_props = {"interactive": True, "variant": "primary", "value": "Start new training"}
|
568 |
-
resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "Start from latest checkpoint"}
|
569 |
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
570 |
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
571 |
else:
|
572 |
# Use button states from recovery, adding the new resume button
|
573 |
-
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start new training"})
|
574 |
resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
|
575 |
-
"variant": "primary", "value": "Start from latest checkpoint"}
|
576 |
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
577 |
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
578 |
|
@@ -604,7 +598,7 @@ class AppUI:
|
|
604 |
|
605 |
return (
|
606 |
gr.Markdown(value=caption_title),
|
607 |
-
gr.Markdown(value=f"{train_title}
|
608 |
)
|
609 |
|
610 |
def refresh_dataset(self):
|
|
|
146 |
# Sidebar for navigation
|
147 |
with gr.Sidebar(position="left", open=True):
|
148 |
gr.Markdown("# 🎞️ Video Model Studio")
|
149 |
+
self.components["current_project_btn"] = gr.Button("📂 Current Project", variant="primary")
|
150 |
+
self.components["system_monitoring_btn"] = gr.Button("🌡️ System Monitoring")
|
151 |
|
152 |
# Main content area with tabs
|
153 |
with gr.Column():
|
|
|
174 |
tab_obj.create(project_tabs)
|
175 |
|
176 |
# Monitoring View Tab
|
177 |
+
with gr.Tab("🌡️ System Monitoring", id=1) as monitoring_view:
|
178 |
# Create monitoring tabs
|
179 |
with gr.Tabs() as monitoring_tabs:
|
180 |
# Store reference to monitoring tabs component
|
|
|
257 |
self.project_tabs["train_tab"].components["stop_btn"],
|
258 |
self.project_tabs["train_tab"].components["delete_checkpoints_btn"]
|
259 |
]
|
260 |
+
|
261 |
button_timer.tick(
|
262 |
fn=self.project_tabs["train_tab"].get_button_updates,
|
263 |
outputs=button_outputs
|
264 |
)
|
265 |
|
266 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
# Dataset refresh timer (every 5 seconds)
|
268 |
dataset_timer = gr.Timer(value=5)
|
269 |
dataset_timer.tick(
|
|
|
552 |
|
553 |
if is_training:
|
554 |
# Active training detected
|
555 |
+
start_btn_props = {"interactive": False, "variant": "secondary", "value": "🚀 Start new training"}
|
556 |
+
resume_btn_props = {"interactive": False, "variant": "secondary", "value": "🛰️ Start from latest checkpoint"}
|
557 |
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
558 |
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
559 |
else:
|
560 |
# No active training
|
561 |
+
start_btn_props = {"interactive": True, "variant": "primary", "value": "🚀 Start new training"}
|
562 |
+
resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "🛰️ Start from latest checkpoint"}
|
563 |
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
564 |
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
565 |
else:
|
566 |
# Use button states from recovery, adding the new resume button
|
567 |
+
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "🚀 Start new training"})
|
568 |
resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
|
569 |
+
"variant": "primary", "value": "🛰️ Start from latest checkpoint"}
|
570 |
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
571 |
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
572 |
|
|
|
598 |
|
599 |
return (
|
600 |
gr.Markdown(value=caption_title),
|
601 |
+
gr.Markdown(value=f"{train_title}")
|
602 |
)
|
603 |
|
604 |
def refresh_dataset(self):
|
vms/ui/monitoring/tabs/general_tab.py
CHANGED
@@ -32,7 +32,7 @@ class GeneralTab(BaseTab):
|
|
32 |
"""Create the Monitor tab UI components"""
|
33 |
with gr.TabItem(self.title, id=self.id) as tab:
|
34 |
with gr.Row():
|
35 |
-
gr.Markdown("## System Monitoring")
|
36 |
|
37 |
# Current metrics
|
38 |
with gr.Row():
|
|
|
32 |
"""Create the Monitor tab UI components"""
|
33 |
with gr.TabItem(self.title, id=self.id) as tab:
|
34 |
with gr.Row():
|
35 |
+
gr.Markdown("## 🌡️ System Monitoring")
|
36 |
|
37 |
# Current metrics
|
38 |
with gr.Row():
|
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -29,23 +29,23 @@ class ManageTab(BaseTab):
|
|
29 |
with gr.TabItem(self.title, id=self.id) as tab:
|
30 |
with gr.Row():
|
31 |
with gr.Column():
|
32 |
-
gr.Markdown("##
|
33 |
gr.Markdown("There is currently a bug, you might have to click multiple times to trigger a download.")
|
34 |
|
35 |
with gr.Row():
|
36 |
self.components["download_dataset_btn"] = gr.DownloadButton(
|
37 |
-
"Download training dataset",
|
38 |
variant="secondary",
|
39 |
size="lg"
|
40 |
)
|
41 |
self.components["download_model_btn"] = gr.DownloadButton(
|
42 |
-
"Download
|
43 |
variant="secondary",
|
44 |
size="lg"
|
45 |
)
|
46 |
with gr.Row():
|
47 |
with gr.Column():
|
48 |
-
gr.Markdown("## Publish your model")
|
49 |
gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
|
50 |
|
51 |
with gr.Row():
|
@@ -65,19 +65,19 @@ class ManageTab(BaseTab):
|
|
65 |
|
66 |
with gr.Row():
|
67 |
with gr.Column():
|
68 |
-
gr.Markdown("## Delete your data")
|
69 |
gr.Markdown("Make sure you have made a backup first.")
|
70 |
gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
|
71 |
|
72 |
with gr.Row():
|
73 |
with gr.Column():
|
74 |
-
gr.Markdown("### Delete specific data")
|
75 |
gr.Markdown("You can selectively delete either the dataset and/or the last model data.")
|
76 |
|
77 |
with gr.Row():
|
78 |
with gr.Column(scale=1):
|
79 |
self.components["delete_dataset_btn"] = gr.Button(
|
80 |
-
"Delete dataset (images, video, captions)",
|
81 |
variant="secondary"
|
82 |
)
|
83 |
self.components["delete_dataset_status"] = gr.Textbox(
|
@@ -88,7 +88,7 @@ class ManageTab(BaseTab):
|
|
88 |
|
89 |
with gr.Column(scale=1):
|
90 |
self.components["delete_model_btn"] = gr.Button(
|
91 |
-
"Delete model (checkpoints, weights, config)",
|
92 |
variant="secondary"
|
93 |
)
|
94 |
self.components["delete_model_status"] = gr.Textbox(
|
@@ -99,12 +99,12 @@ class ManageTab(BaseTab):
|
|
99 |
|
100 |
with gr.Row():
|
101 |
with gr.Column():
|
102 |
-
gr.Markdown("###
|
103 |
-
gr.Markdown("This will
|
104 |
|
105 |
with gr.Row():
|
106 |
self.components["global_stop_btn"] = gr.Button(
|
107 |
-
"
|
108 |
variant="stop"
|
109 |
)
|
110 |
self.components["global_status"] = gr.Textbox(
|
|
|
29 |
with gr.TabItem(self.title, id=self.id) as tab:
|
30 |
with gr.Row():
|
31 |
with gr.Column():
|
32 |
+
gr.Markdown("## 🏦 Backup your model")
|
33 |
gr.Markdown("There is currently a bug, you might have to click multiple times to trigger a download.")
|
34 |
|
35 |
with gr.Row():
|
36 |
self.components["download_dataset_btn"] = gr.DownloadButton(
|
37 |
+
"📦 Download training dataset (.zip)",
|
38 |
variant="secondary",
|
39 |
size="lg"
|
40 |
)
|
41 |
self.components["download_model_btn"] = gr.DownloadButton(
|
42 |
+
"🧠 Download weights (.safetensors)",
|
43 |
variant="secondary",
|
44 |
size="lg"
|
45 |
)
|
46 |
with gr.Row():
|
47 |
with gr.Column():
|
48 |
+
gr.Markdown("## 📡 Publish your model")
|
49 |
gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
|
50 |
|
51 |
with gr.Row():
|
|
|
65 |
|
66 |
with gr.Row():
|
67 |
with gr.Column():
|
68 |
+
gr.Markdown("## ♻️ Delete your data")
|
69 |
gr.Markdown("Make sure you have made a backup first.")
|
70 |
gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
|
71 |
|
72 |
with gr.Row():
|
73 |
with gr.Column():
|
74 |
+
gr.Markdown("### 🧽 Delete specific data")
|
75 |
gr.Markdown("You can selectively delete either the dataset and/or the last model data.")
|
76 |
|
77 |
with gr.Row():
|
78 |
with gr.Column(scale=1):
|
79 |
self.components["delete_dataset_btn"] = gr.Button(
|
80 |
+
"🚨 Delete dataset (images, video, captions)",
|
81 |
variant="secondary"
|
82 |
)
|
83 |
self.components["delete_dataset_status"] = gr.Textbox(
|
|
|
88 |
|
89 |
with gr.Column(scale=1):
|
90 |
self.components["delete_model_btn"] = gr.Button(
|
91 |
+
"🚨 Delete model (checkpoints, weights, config)",
|
92 |
variant="secondary"
|
93 |
)
|
94 |
self.components["delete_model_status"] = gr.Textbox(
|
|
|
99 |
|
100 |
with gr.Row():
|
101 |
with gr.Column():
|
102 |
+
gr.Markdown("### ☢️ Nuke all project data")
|
103 |
+
gr.Markdown("This will nuke the original dataset (all images, videos and captions), the training dataset, and the model outputs (weights, checkpoints, settings). So use with care!")
|
104 |
|
105 |
with gr.Row():
|
106 |
self.components["global_stop_btn"] = gr.Button(
|
107 |
+
"🚨 Delete all project data and models (are you sure?!)",
|
108 |
variant="stop"
|
109 |
)
|
110 |
self.components["global_status"] = gr.Textbox(
|
vms/ui/project/tabs/preview_tab.py
CHANGED
@@ -29,7 +29,7 @@ class PreviewTab(BaseTab):
|
|
29 |
"""Create the Preview tab UI components"""
|
30 |
with gr.TabItem(self.title, id=self.id) as tab:
|
31 |
with gr.Row():
|
32 |
-
gr.Markdown("## Preview your model")
|
33 |
|
34 |
with gr.Row():
|
35 |
with gr.Column(scale=2):
|
@@ -202,11 +202,11 @@ class PreviewTab(BaseTab):
|
|
202 |
interactive=False
|
203 |
)
|
204 |
|
205 |
-
with gr.Accordion("Log", open=
|
206 |
self.components["log"] = gr.TextArea(
|
207 |
label="Generation Log",
|
208 |
interactive=False,
|
209 |
-
lines=
|
210 |
)
|
211 |
|
212 |
return tab
|
|
|
29 |
"""Create the Preview tab UI components"""
|
30 |
with gr.TabItem(self.title, id=self.id) as tab:
|
31 |
with gr.Row():
|
32 |
+
gr.Markdown("## 🔬 Preview your model")
|
33 |
|
34 |
with gr.Row():
|
35 |
with gr.Column(scale=2):
|
|
|
202 |
interactive=False
|
203 |
)
|
204 |
|
205 |
+
with gr.Accordion("Log", open=False):
|
206 |
self.components["log"] = gr.TextArea(
|
207 |
label="Generation Log",
|
208 |
interactive=False,
|
209 |
+
lines=60
|
210 |
)
|
211 |
|
212 |
return tab
|
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -44,7 +44,7 @@ class TrainTab(BaseTab):
|
|
44 |
with gr.Row():
|
45 |
with gr.Column():
|
46 |
with gr.Row():
|
47 |
-
self.components["train_title"] = gr.Markdown("## 0 files
|
48 |
|
49 |
with gr.Row():
|
50 |
with gr.Column():
|
@@ -181,79 +181,80 @@ class TrainTab(BaseTab):
|
|
181 |
|
182 |
with gr.Row():
|
183 |
with gr.Column():
|
184 |
-
# Add description of the training buttons
|
185 |
-
self.components["training_buttons_info"] = gr.Markdown("""
|
186 |
-
## Training Options
|
187 |
-
- **Start new training**: Begins training from scratch (clears previous checkpoints)
|
188 |
-
- **Start from latest checkpoint**: Continues training from the most recent checkpoint
|
189 |
-
""")
|
190 |
-
|
191 |
-
with gr.Row():
|
192 |
-
# Check for existing checkpoints to determine button text
|
193 |
-
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
|
194 |
-
has_checkpoints = len(checkpoints) > 0
|
195 |
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
with gr.Row():
|
211 |
-
# Just use stop and pause buttons for now to ensure compatibility
|
212 |
-
self.components["stop_btn"] = gr.Button(
|
213 |
-
"Stop at Last Checkpoint",
|
214 |
-
variant="primary",
|
215 |
-
interactive=False
|
216 |
-
)
|
217 |
-
|
218 |
-
self.components["pause_resume_btn"] = gr.Button(
|
219 |
-
"Resume Training",
|
220 |
-
variant="secondary",
|
221 |
-
interactive=False,
|
222 |
-
visible=False
|
223 |
-
)
|
224 |
-
|
225 |
-
# Add delete checkpoints button
|
226 |
-
self.components["delete_checkpoints_btn"] = gr.Button(
|
227 |
-
"Delete All Checkpoints",
|
228 |
-
variant="stop",
|
229 |
-
interactive=has_checkpoints
|
230 |
-
)
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
interactive=False,
|
253 |
-
lines=
|
254 |
-
|
255 |
-
autoscroll=True
|
256 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
return tab
|
259 |
|
@@ -963,7 +964,7 @@ class TrainTab(BaseTab):
|
|
963 |
|
964 |
# Create button updates
|
965 |
start_btn = gr.Button(
|
966 |
-
value="Start new training",
|
967 |
interactive=not is_training,
|
968 |
variant="primary" if not is_training else "secondary"
|
969 |
)
|
|
|
44 |
with gr.Row():
|
45 |
with gr.Column():
|
46 |
with gr.Row():
|
47 |
+
self.components["train_title"] = gr.Markdown("## 0 files in the training dataset")
|
48 |
|
49 |
with gr.Row():
|
50 |
with gr.Column():
|
|
|
181 |
|
182 |
with gr.Row():
|
183 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
with gr.Row():
|
186 |
+
with gr.Column():
|
187 |
+
# Add description of the training buttons
|
188 |
+
self.components["training_buttons_info"] = gr.Markdown("""
|
189 |
+
## ⚗️ Train your model on your dataset
|
190 |
+
- **Start new training**: Begins training from scratch (clears previous checkpoints)
|
191 |
+
- **Start from latest checkpoint**: Continues training from the most recent checkpoint
|
192 |
+
""")
|
193 |
+
|
194 |
+
with gr.Row():
|
195 |
+
# Check for existing checkpoints to determine button text
|
196 |
+
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
|
197 |
+
has_checkpoints = len(checkpoints) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
199 |
+
self.components["start_btn"] = gr.Button(
|
200 |
+
"🚀 Start new training",
|
201 |
+
variant="primary",
|
202 |
+
interactive=not ASK_USER_TO_DUPLICATE_SPACE
|
203 |
+
)
|
204 |
+
|
205 |
+
# Add new button for continuing from checkpoint
|
206 |
+
self.components["resume_btn"] = gr.Button(
|
207 |
+
"🛰️ Start from latest checkpoint",
|
208 |
+
variant="primary",
|
209 |
+
interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
|
210 |
+
)
|
211 |
+
|
212 |
+
with gr.Row():
|
213 |
+
# Just use stop and pause buttons for now to ensure compatibility
|
214 |
+
self.components["stop_btn"] = gr.Button(
|
215 |
+
"Stop at Last Checkpoint",
|
216 |
+
variant="primary",
|
217 |
+
interactive=False
|
218 |
+
)
|
219 |
+
|
220 |
+
self.components["pause_resume_btn"] = gr.Button(
|
221 |
+
"Resume Training",
|
222 |
+
variant="secondary",
|
223 |
+
interactive=False,
|
224 |
+
visible=False
|
225 |
+
)
|
226 |
+
|
227 |
+
# Add delete checkpoints button
|
228 |
+
self.components["delete_checkpoints_btn"] = gr.Button(
|
229 |
+
"Delete All Checkpoints",
|
230 |
+
variant="stop",
|
231 |
+
interactive=has_checkpoints
|
232 |
+
)
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
with gr.Column():
|
236 |
+
self.components["status_box"] = gr.Textbox(
|
237 |
+
label="Training Status",
|
238 |
+
interactive=False,
|
239 |
+
lines=4
|
240 |
+
)
|
241 |
+
|
242 |
+
# Add new component for current task progress
|
243 |
+
self.components["current_task_box"] = gr.Textbox(
|
244 |
+
label="Current Task Progress",
|
245 |
interactive=False,
|
246 |
+
lines=3,
|
247 |
+
elem_id="current_task_display"
|
|
|
248 |
)
|
249 |
+
|
250 |
+
with gr.Accordion("Finetrainers output (or see app logs for more details)", open=False):
|
251 |
+
self.components["log_box"] = gr.TextArea(
|
252 |
+
#label="",
|
253 |
+
interactive=False,
|
254 |
+
lines=60,
|
255 |
+
max_lines=600,
|
256 |
+
autoscroll=True
|
257 |
+
)
|
258 |
|
259 |
return tab
|
260 |
|
|
|
964 |
|
965 |
# Create button updates
|
966 |
start_btn = gr.Button(
|
967 |
+
value="🚀 Start new training",
|
968 |
interactive=not is_training,
|
969 |
variant="primary" if not is_training else "secondary"
|
970 |
)
|