Spaces:
Running
Running
Commit
Β·
89bbef2
1
Parent(s):
b91a6aa
fixes for the dataset importer
Browse files- app.py +2 -2
- vms/ui/__init__.py +2 -2
- vms/ui/{video_trainer_ui.py β app_ui.py} +150 -68
- vms/ui/monitoring/services/__init__.py +5 -0
- vms/{services β ui/monitoring/services}/monitoring.py +0 -0
- vms/ui/monitoring/tabs/__init__.py +9 -0
- vms/{tabs/monitor_tab.py β ui/monitoring/tabs/general_tab.py} +7 -26
- vms/ui/monitoring/utils/__init__.py +7 -0
- vms/ui/monitoring/utils/get_folder_size.py +22 -0
- vms/ui/monitoring/utils/human_readable_size.py +10 -0
- vms/{services β ui/project/services}/__init__.py +0 -2
- vms/{services β ui/project/services}/captioning.py +2 -2
- vms/{services β ui/project/services}/importing/__init__.py +0 -0
- vms/{services β ui/project/services}/importing/file_upload.py +0 -0
- vms/{services β ui/project/services}/importing/hub_dataset.py +5 -5
- vms/{services β ui/project/services}/importing/import_service.py +4 -3
- vms/{services β ui/project/services}/importing/youtube.py +0 -0
- vms/{services β ui/project/services}/previewing.py +0 -0
- vms/{services β ui/project/services}/splitting.py +2 -2
- vms/{services β ui/project/services}/training.py +2 -2
- vms/{tabs β ui/project/tabs}/__init__.py +1 -5
- vms/{tabs β ui/project/tabs}/caption_tab.py +4 -6
- vms/{tabs β ui/project/tabs}/import_tab/__init__.py +0 -0
- vms/{tabs β ui/project/tabs}/import_tab/hub_tab.py +20 -12
- vms/{tabs β ui/project/tabs}/import_tab/import_tab.py +19 -13
- vms/{tabs β ui/project/tabs}/import_tab/upload_tab.py +51 -19
- vms/{tabs β ui/project/tabs}/import_tab/youtube_tab.py +52 -16
- vms/{tabs β ui/project/tabs}/manage_tab.py +53 -49
- vms/{tabs β ui/project/tabs}/preview_tab.py +4 -4
- vms/{tabs β ui/project/tabs}/split_tab.py +1 -1
- vms/{tabs β ui/project/tabs}/train_tab.py +2 -2
- vms/utils/__init__.py +4 -1
- vms/{tabs β utils}/base_tab.py +1 -1
app.py
CHANGED
@@ -15,7 +15,7 @@ from vms.config import (
|
|
15 |
HF_API_TOKEN
|
16 |
)
|
17 |
|
18 |
-
from vms.ui.
|
19 |
|
20 |
# Configure logging
|
21 |
logger = logging.getLogger(__name__)
|
@@ -37,7 +37,7 @@ To avoid overpaying for your space, you can configure the auto-sleep settings to
|
|
37 |
return app
|
38 |
|
39 |
# Create the main application UI
|
40 |
-
ui =
|
41 |
app = ui.create_ui()
|
42 |
|
43 |
return app
|
|
|
15 |
HF_API_TOKEN
|
16 |
)
|
17 |
|
18 |
+
from vms.ui.app_ui import AppUI
|
19 |
|
20 |
# Configure logging
|
21 |
logger = logging.getLogger(__name__)
|
|
|
37 |
return app
|
38 |
|
39 |
# Create the main application UI
|
40 |
+
ui = AppUI()
|
41 |
app = ui.create_ui()
|
42 |
|
43 |
return app
|
vms/ui/__init__.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
from .
|
2 |
|
3 |
__all__ = [
|
4 |
-
'
|
5 |
]
|
|
|
1 |
+
from .app_ui import AppUI
|
2 |
|
3 |
__all__ = [
|
4 |
+
'AppUI',
|
5 |
]
|
vms/ui/{video_trainer_ui.py β app_ui.py}
RENAMED
@@ -5,8 +5,7 @@ import logging
|
|
5 |
import asyncio
|
6 |
from typing import Any, Optional, Dict, List, Union, Tuple
|
7 |
|
8 |
-
from
|
9 |
-
from ..config import (
|
10 |
STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
|
11 |
TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
|
12 |
MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
|
@@ -22,13 +21,27 @@ from ..config import (
|
|
22 |
DEFAULT_NB_TRAINING_STEPS,
|
23 |
DEFAULT_NB_LR_WARMUP_STEPS
|
24 |
)
|
25 |
-
from
|
26 |
get_recommended_precomputation_items,
|
27 |
count_media_files,
|
28 |
format_media_title,
|
29 |
TrainingLogParser
|
30 |
)
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
logger = logging.getLogger(__name__)
|
34 |
logger.setLevel(logging.INFO)
|
@@ -36,22 +49,23 @@ logger.setLevel(logging.INFO)
|
|
36 |
httpx_logger = logging.getLogger('httpx')
|
37 |
httpx_logger.setLevel(logging.WARN)
|
38 |
|
39 |
-
class
|
40 |
def __init__(self):
|
41 |
"""Initialize services and tabs"""
|
42 |
-
#
|
43 |
self.training = TrainingService(self)
|
44 |
self.splitting = SplittingService()
|
45 |
self.importing = ImportingService()
|
46 |
self.captioning = CaptioningService()
|
47 |
-
self.monitoring = MonitoringService()
|
48 |
self.previewing = PreviewingService()
|
49 |
|
50 |
-
#
|
|
|
51 |
self.monitoring.start_monitoring()
|
52 |
|
53 |
# Recovery status from any interrupted training
|
54 |
recovery_result = self.training.recover_interrupted_training()
|
|
|
55 |
# Add null check for recovery_result
|
56 |
if recovery_result is None:
|
57 |
recovery_result = {"status": "unknown", "ui_updates": {}}
|
@@ -67,9 +81,13 @@ class VideoTrainerUI:
|
|
67 |
"recovery_result": recovery_result
|
68 |
}
|
69 |
|
70 |
-
# Initialize tabs dictionary
|
71 |
self.tabs = {}
|
72 |
-
self.
|
|
|
|
|
|
|
|
|
73 |
|
74 |
# Log recovery status
|
75 |
logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")
|
@@ -99,89 +117,153 @@ class VideoTrainerUI:
|
|
99 |
logger.info(f"Added periodic callback {callback_fn.__name__} with interval {interval}s")
|
100 |
except Exception as e:
|
101 |
logger.error(f"Error adding periodic callback: {e}", exc_info=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def create_ui(self):
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
# Create main tabs component
|
109 |
-
with gr.Tabs() as self.tabs_component:
|
110 |
-
# Initialize tab objects
|
111 |
-
self.tabs["import_tab"] = ImportTab(self)
|
112 |
-
self.tabs["split_tab"] = SplitTab(self)
|
113 |
-
self.tabs["caption_tab"] = CaptionTab(self)
|
114 |
-
self.tabs["train_tab"] = TrainTab(self)
|
115 |
-
self.tabs["monitor_tab"] = MonitorTab(self)
|
116 |
-
self.tabs["preview_tab"] = PreviewTab(self)
|
117 |
-
self.tabs["manage_tab"] = ManageTab(self)
|
118 |
-
|
119 |
-
# Create tab UI components
|
120 |
-
for tab_id, tab_obj in self.tabs.items():
|
121 |
-
tab_obj.create(self.tabs_component)
|
122 |
|
123 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
for tab_id, tab_obj in self.tabs.items():
|
125 |
tab_obj.connect_events()
|
126 |
|
127 |
# app-level timers for auto-refresh functionality
|
128 |
self._add_timers()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
# Initialize app state on load
|
131 |
app.load(
|
132 |
fn=self.initialize_app_state,
|
133 |
outputs=[
|
134 |
-
self.
|
135 |
-
self.
|
136 |
-
self.
|
137 |
-
self.
|
138 |
-
self.
|
139 |
-
self.
|
140 |
-
self.
|
141 |
-
self.
|
142 |
-
self.
|
143 |
-
self.
|
144 |
-
self.
|
145 |
-
self.
|
146 |
-
self.
|
147 |
-
self.
|
148 |
-
self.
|
149 |
-
self.
|
150 |
-
self.
|
151 |
-
self.
|
152 |
]
|
153 |
)
|
154 |
-
|
155 |
return app
|
156 |
-
|
157 |
def _add_timers(self):
|
158 |
"""Add auto-refresh timers to the UI"""
|
159 |
# Status update timer for text components (every 1 second)
|
160 |
status_timer = gr.Timer(value=1)
|
161 |
status_timer.tick(
|
162 |
-
fn=self.
|
163 |
outputs=[
|
164 |
-
self.
|
165 |
-
self.
|
166 |
-
self.
|
167 |
]
|
168 |
)
|
169 |
|
170 |
# Button update timer for button components (every 1 second)
|
171 |
button_timer = gr.Timer(value=1)
|
172 |
button_outputs = [
|
173 |
-
self.
|
174 |
-
self.
|
175 |
]
|
176 |
|
177 |
# Add delete_checkpoints_btn or pause_resume_btn as the third button
|
178 |
-
if "delete_checkpoints_btn" in self.
|
179 |
-
button_outputs.append(self.
|
180 |
-
elif "pause_resume_btn" in self.
|
181 |
-
button_outputs.append(self.
|
182 |
|
183 |
button_timer.tick(
|
184 |
-
fn=self.
|
185 |
outputs=button_outputs
|
186 |
)
|
187 |
|
@@ -190,8 +272,8 @@ class VideoTrainerUI:
|
|
190 |
dataset_timer.tick(
|
191 |
fn=self.refresh_dataset,
|
192 |
outputs=[
|
193 |
-
self.
|
194 |
-
self.
|
195 |
]
|
196 |
)
|
197 |
|
@@ -200,17 +282,17 @@ class VideoTrainerUI:
|
|
200 |
titles_timer.tick(
|
201 |
fn=self.update_titles,
|
202 |
outputs=[
|
203 |
-
self.
|
204 |
-
self.
|
205 |
-
self.
|
206 |
]
|
207 |
)
|
208 |
|
209 |
def initialize_app_state(self):
|
210 |
"""Initialize all app state in one function to ensure correct output count"""
|
211 |
# Get dataset info
|
212 |
-
video_list = self.
|
213 |
-
training_dataset = self.
|
214 |
|
215 |
# Get button states based on recovery status
|
216 |
button_states = self.get_initial_button_states()
|
@@ -474,8 +556,8 @@ class VideoTrainerUI:
|
|
474 |
|
475 |
def refresh_dataset(self):
|
476 |
"""Refresh all dynamic lists and training state"""
|
477 |
-
video_list = self.
|
478 |
-
training_dataset = self.
|
479 |
|
480 |
return (
|
481 |
video_list,
|
|
|
5 |
import asyncio
|
6 |
from typing import Any, Optional, Dict, List, Union, Tuple
|
7 |
|
8 |
+
from vms.config import (
|
|
|
9 |
STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
|
10 |
TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
|
11 |
MODEL_TYPES, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
|
|
|
21 |
DEFAULT_NB_TRAINING_STEPS,
|
22 |
DEFAULT_NB_LR_WARMUP_STEPS
|
23 |
)
|
24 |
+
from vms.utils import (
|
25 |
get_recommended_precomputation_items,
|
26 |
count_media_files,
|
27 |
format_media_title,
|
28 |
TrainingLogParser
|
29 |
)
|
30 |
+
|
31 |
+
from vms.ui.project.services import (
|
32 |
+
TrainingService, CaptioningService, SplittingService, ImportingService, PreviewingService
|
33 |
+
)
|
34 |
+
from vms.ui.project.tabs import (
|
35 |
+
ImportTab, SplitTab, CaptionTab, TrainTab, PreviewTab, ManageTab
|
36 |
+
)
|
37 |
+
|
38 |
+
from vms.ui.monitoring.services import (
|
39 |
+
MonitoringService
|
40 |
+
)
|
41 |
+
|
42 |
+
from vms.ui.monitoring.tabs import (
|
43 |
+
GeneralTab
|
44 |
+
)
|
45 |
|
46 |
logger = logging.getLogger(__name__)
|
47 |
logger.setLevel(logging.INFO)
|
|
|
49 |
httpx_logger = logging.getLogger('httpx')
|
50 |
httpx_logger.setLevel(logging.WARN)
|
51 |
|
52 |
+
class AppUI:
|
53 |
def __init__(self):
|
54 |
"""Initialize services and tabs"""
|
55 |
+
# Project view
|
56 |
self.training = TrainingService(self)
|
57 |
self.splitting = SplittingService()
|
58 |
self.importing = ImportingService()
|
59 |
self.captioning = CaptioningService()
|
|
|
60 |
self.previewing = PreviewingService()
|
61 |
|
62 |
+
# Monitoring view
|
63 |
+
self.monitoring = MonitoringService()
|
64 |
self.monitoring.start_monitoring()
|
65 |
|
66 |
# Recovery status from any interrupted training
|
67 |
recovery_result = self.training.recover_interrupted_training()
|
68 |
+
|
69 |
# Add null check for recovery_result
|
70 |
if recovery_result is None:
|
71 |
recovery_result = {"status": "unknown", "ui_updates": {}}
|
|
|
81 |
"recovery_result": recovery_result
|
82 |
}
|
83 |
|
84 |
+
# Initialize tabs dictionary
|
85 |
self.tabs = {}
|
86 |
+
self.project_tabs = {}
|
87 |
+
self.monitor_tabs = {}
|
88 |
+
self.main_tabs = None # Main tabbed interface
|
89 |
+
self.project_tabs_component = None # Project sub-tabs
|
90 |
+
self.monitor_tabs_component = None # Monitor sub-tabs
|
91 |
|
92 |
# Log recovery status
|
93 |
logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")
|
|
|
117 |
logger.info(f"Added periodic callback {callback_fn.__name__} with interval {interval}s")
|
118 |
except Exception as e:
|
119 |
logger.error(f"Error adding periodic callback: {e}", exc_info=True)
|
120 |
+
|
121 |
+
def switch_to_tab(self, tab_index: int):
|
122 |
+
"""Switch to the specified tab index
|
123 |
+
|
124 |
+
Args:
|
125 |
+
tab_index: Index of the tab to select (0 for Project, 1 for Monitor)
|
126 |
|
127 |
+
Returns:
|
128 |
+
Tab selection dictionary for Gradio
|
129 |
+
"""
|
130 |
+
|
131 |
+
return gr.Tabs(selected=tab_index)
|
132 |
+
|
133 |
def create_ui(self):
|
134 |
+
self.components = {}
|
135 |
+
"""Create the main Gradio UI with tabbed navigation"""
|
136 |
+
with gr.Blocks(
|
137 |
+
title="ποΈ Video Model Studio",
|
138 |
+
|
139 |
+
# Let's hack Gradio!
|
140 |
+
css="#component-8 > .tab-wrapper{ display: none; }") as app:
|
141 |
+
self.app = app
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
+
# Main container with sidebar and tab area
|
145 |
+
with gr.Row():
|
146 |
+
# Sidebar for navigation
|
147 |
+
with gr.Sidebar(position="left", open=True):
|
148 |
+
gr.Markdown("# ποΈ Video Model Studio")
|
149 |
+
self.components["current_project_btn"] = gr.Button("Current Project", variant="primary")
|
150 |
+
self.components["system_monitoring_btn"] = gr.Button("System Monitoring")
|
151 |
+
|
152 |
+
# Main content area with tabs
|
153 |
+
with gr.Column():
|
154 |
+
# Main tabbed interface for switching between Project and Monitor views
|
155 |
+
with gr.Tabs() as main_tabs:
|
156 |
+
self.main_tabs = main_tabs
|
157 |
+
|
158 |
+
# Project View Tab
|
159 |
+
with gr.Tab("π Current Project", id=0) as project_view:
|
160 |
+
# Create project tabs
|
161 |
+
with gr.Tabs() as project_tabs:
|
162 |
+
# Store reference to project tabs component
|
163 |
+
self.project_tabs_component = project_tabs
|
164 |
+
|
165 |
+
# Initialize project tab objects
|
166 |
+
self.project_tabs["import_tab"] = ImportTab(self)
|
167 |
+
self.project_tabs["split_tab"] = SplitTab(self)
|
168 |
+
self.project_tabs["caption_tab"] = CaptionTab(self)
|
169 |
+
self.project_tabs["train_tab"] = TrainTab(self)
|
170 |
+
self.project_tabs["preview_tab"] = PreviewTab(self)
|
171 |
+
self.project_tabs["manage_tab"] = ManageTab(self)
|
172 |
+
|
173 |
+
# Create tab UI components for project
|
174 |
+
for tab_id, tab_obj in self.project_tabs.items():
|
175 |
+
tab_obj.create(project_tabs)
|
176 |
+
|
177 |
+
# Monitoring View Tab
|
178 |
+
with gr.Tab("π System Monitoring", id=1) as monitoring_view:
|
179 |
+
# Create monitoring tabs
|
180 |
+
with gr.Tabs() as monitoring_tabs:
|
181 |
+
# Store reference to monitoring tabs component
|
182 |
+
self.monitor_tabs_component = monitoring_tabs
|
183 |
+
|
184 |
+
# Initialize monitoring tab objects
|
185 |
+
self.monitor_tabs["general_tab"] = GeneralTab(self)
|
186 |
+
|
187 |
+
# Create tab UI components for monitoring
|
188 |
+
for tab_id, tab_obj in self.monitor_tabs.items():
|
189 |
+
tab_obj.create(monitoring_tabs)
|
190 |
+
|
191 |
+
# Combine all tabs into a single dictionary for event handling
|
192 |
+
self.tabs = {**self.project_tabs, **self.monitor_tabs}
|
193 |
+
|
194 |
+
# Connect event handlers for all tabs - this must happen AFTER all tabs are created
|
195 |
for tab_id, tab_obj in self.tabs.items():
|
196 |
tab_obj.connect_events()
|
197 |
|
198 |
# app-level timers for auto-refresh functionality
|
199 |
self._add_timers()
|
200 |
+
|
201 |
+
# Connect navigation events using tab switching
|
202 |
+
self.components["current_project_btn"].click(
|
203 |
+
fn=lambda: self.switch_to_tab(0),
|
204 |
+
outputs=[self.main_tabs],
|
205 |
+
)
|
206 |
+
|
207 |
+
self.components["system_monitoring_btn"].click(
|
208 |
+
fn=lambda: self.switch_to_tab(1),
|
209 |
+
outputs=[self.main_tabs],
|
210 |
+
)
|
211 |
|
212 |
# Initialize app state on load
|
213 |
app.load(
|
214 |
fn=self.initialize_app_state,
|
215 |
outputs=[
|
216 |
+
self.project_tabs["split_tab"].components["video_list"],
|
217 |
+
self.project_tabs["caption_tab"].components["training_dataset"],
|
218 |
+
self.project_tabs["train_tab"].components["start_btn"],
|
219 |
+
self.project_tabs["train_tab"].components["stop_btn"],
|
220 |
+
self.project_tabs["train_tab"].components["pause_resume_btn"],
|
221 |
+
self.project_tabs["train_tab"].components["training_preset"],
|
222 |
+
self.project_tabs["train_tab"].components["model_type"],
|
223 |
+
self.project_tabs["train_tab"].components["training_type"],
|
224 |
+
self.project_tabs["train_tab"].components["lora_rank"],
|
225 |
+
self.project_tabs["train_tab"].components["lora_alpha"],
|
226 |
+
self.project_tabs["train_tab"].components["train_steps"],
|
227 |
+
self.project_tabs["train_tab"].components["batch_size"],
|
228 |
+
self.project_tabs["train_tab"].components["learning_rate"],
|
229 |
+
self.project_tabs["train_tab"].components["save_iterations"],
|
230 |
+
self.project_tabs["train_tab"].components["current_task_box"],
|
231 |
+
self.project_tabs["train_tab"].components["num_gpus"],
|
232 |
+
self.project_tabs["train_tab"].components["precomputation_items"],
|
233 |
+
self.project_tabs["train_tab"].components["lr_warmup_steps"]
|
234 |
]
|
235 |
)
|
236 |
+
|
237 |
return app
|
238 |
+
|
239 |
def _add_timers(self):
|
240 |
"""Add auto-refresh timers to the UI"""
|
241 |
# Status update timer for text components (every 1 second)
|
242 |
status_timer = gr.Timer(value=1)
|
243 |
status_timer.tick(
|
244 |
+
fn=self.project_tabs["train_tab"].get_status_updates, # Use a new function that returns appropriate updates
|
245 |
outputs=[
|
246 |
+
self.project_tabs["train_tab"].components["status_box"],
|
247 |
+
self.project_tabs["train_tab"].components["log_box"],
|
248 |
+
self.project_tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.project_tabs["train_tab"].components else None
|
249 |
]
|
250 |
)
|
251 |
|
252 |
# Button update timer for button components (every 1 second)
|
253 |
button_timer = gr.Timer(value=1)
|
254 |
button_outputs = [
|
255 |
+
self.project_tabs["train_tab"].components["start_btn"],
|
256 |
+
self.project_tabs["train_tab"].components["stop_btn"]
|
257 |
]
|
258 |
|
259 |
# Add delete_checkpoints_btn or pause_resume_btn as the third button
|
260 |
+
if "delete_checkpoints_btn" in self.project_tabs["train_tab"].components:
|
261 |
+
button_outputs.append(self.project_tabs["train_tab"].components["delete_checkpoints_btn"])
|
262 |
+
elif "pause_resume_btn" in self.project_tabs["train_tab"].components:
|
263 |
+
button_outputs.append(self.project_tabs["train_tab"].components["pause_resume_btn"])
|
264 |
|
265 |
button_timer.tick(
|
266 |
+
fn=self.project_tabs["train_tab"].get_button_updates, # Use a new function for button-specific updates
|
267 |
outputs=button_outputs
|
268 |
)
|
269 |
|
|
|
272 |
dataset_timer.tick(
|
273 |
fn=self.refresh_dataset,
|
274 |
outputs=[
|
275 |
+
self.project_tabs["split_tab"].components["video_list"],
|
276 |
+
self.project_tabs["caption_tab"].components["training_dataset"]
|
277 |
]
|
278 |
)
|
279 |
|
|
|
282 |
titles_timer.tick(
|
283 |
fn=self.update_titles,
|
284 |
outputs=[
|
285 |
+
self.project_tabs["split_tab"].components["split_title"],
|
286 |
+
self.project_tabs["caption_tab"].components["caption_title"],
|
287 |
+
self.project_tabs["train_tab"].components["train_title"]
|
288 |
]
|
289 |
)
|
290 |
|
291 |
def initialize_app_state(self):
|
292 |
"""Initialize all app state in one function to ensure correct output count"""
|
293 |
# Get dataset info
|
294 |
+
video_list = self.project_tabs["split_tab"].list_unprocessed_videos()
|
295 |
+
training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()
|
296 |
|
297 |
# Get button states based on recovery status
|
298 |
button_states = self.get_initial_button_states()
|
|
|
556 |
|
557 |
def refresh_dataset(self):
|
558 |
"""Refresh all dynamic lists and training state"""
|
559 |
+
video_list = self.project_tabs["split_tab"].list_unprocessed_videos()
|
560 |
+
training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()
|
561 |
|
562 |
return (
|
563 |
video_list,
|
vms/ui/monitoring/services/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .monitoring import MonitoringService
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
'MonitoringService',
|
5 |
+
]
|
vms/{services β ui/monitoring/services}/monitoring.py
RENAMED
File without changes
|
vms/ui/monitoring/tabs/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tab components for the "monitor" view
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .general_tab import GeneralTab
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'GeneralTab'
|
9 |
+
]
|
vms/{tabs/monitor_tab.py β ui/monitoring/tabs/general_tab.py}
RENAMED
@@ -12,39 +12,20 @@ import psutil
|
|
12 |
from typing import Dict, Any, List, Optional, Tuple
|
13 |
from datetime import datetime, timedelta
|
14 |
|
15 |
-
from .base_tab import BaseTab
|
16 |
-
from
|
|
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
-
def get_folder_size(path):
|
21 |
-
"""Calculate the total size of a folder in bytes"""
|
22 |
-
total_size = 0
|
23 |
-
for dirpath, dirnames, filenames in os.walk(path):
|
24 |
-
for filename in filenames:
|
25 |
-
file_path = os.path.join(dirpath, filename)
|
26 |
-
if not os.path.islink(file_path): # Skip symlinks
|
27 |
-
total_size += os.path.getsize(file_path)
|
28 |
-
return total_size
|
29 |
|
30 |
-
|
31 |
-
"""
|
32 |
-
if size_bytes == 0:
|
33 |
-
return "0 B"
|
34 |
-
size_names = ("B", "KB", "MB", "GB", "TB", "PB")
|
35 |
-
i = 0
|
36 |
-
while size_bytes >= 1024 and i < len(size_names) - 1:
|
37 |
-
size_bytes /= 1024
|
38 |
-
i += 1
|
39 |
-
return f"{size_bytes:.2f} {size_names[i]}"
|
40 |
-
|
41 |
-
class MonitorTab(BaseTab):
|
42 |
-
"""Monitor tab for system resource monitoring"""
|
43 |
|
44 |
def __init__(self, app_state):
|
45 |
super().__init__(app_state)
|
46 |
-
self.id = "
|
47 |
-
self.title = "
|
48 |
self.refresh_interval = 8
|
49 |
|
50 |
def create(self, parent=None) -> gr.TabItem:
|
|
|
12 |
from typing import Dict, Any, List, Optional, Tuple
|
13 |
from datetime import datetime, timedelta
|
14 |
|
15 |
+
from vms.utils.base_tab import BaseTab
|
16 |
+
from vms.config import STORAGE_PATH
|
17 |
+
from vms.ui.monitoring.utils import get_folder_size, human_readable_size
|
18 |
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
+
class GeneralTab(BaseTab):
|
23 |
+
"""Monitor tab for general system resource monitoring"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def __init__(self, app_state):
|
26 |
super().__init__(app_state)
|
27 |
+
self.id = "General_tab"
|
28 |
+
self.title = "General stats"
|
29 |
self.refresh_interval = 8
|
30 |
|
31 |
def create(self, parent=None) -> gr.TabItem:
|
vms/ui/monitoring/utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .get_folder_size import get_folder_size
|
2 |
+
from .human_readable_size import human_readable_size
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
'get_folder_size',
|
6 |
+
'human_readable_size',
|
7 |
+
]
|
vms/ui/monitoring/utils/get_folder_size.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
def get_folder_size(path):
|
4 |
+
"""Calculate the total size of a folder in bytes"""
|
5 |
+
total_size = 0
|
6 |
+
for dirpath, dirnames, filenames in os.walk(path):
|
7 |
+
for filename in filenames:
|
8 |
+
file_path = os.path.join(dirpath, filename)
|
9 |
+
if not os.path.islink(file_path): # Skip symlinks
|
10 |
+
total_size += os.path.getsize(file_path)
|
11 |
+
return total_size
|
12 |
+
|
13 |
+
def human_readable_size(size_bytes):
|
14 |
+
"""Convert a size in bytes to a human-readable string"""
|
15 |
+
if size_bytes == 0:
|
16 |
+
return "0 B"
|
17 |
+
size_names = ("B", "KB", "MB", "GB", "TB", "PB")
|
18 |
+
i = 0
|
19 |
+
while size_bytes >= 1024 and i < len(size_names) - 1:
|
20 |
+
size_bytes /= 1024
|
21 |
+
i += 1
|
22 |
+
return f"{size_bytes:.2f} {size_names[i]}"
|
vms/ui/monitoring/utils/human_readable_size.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def human_readable_size(size_bytes):
|
2 |
+
"""Convert a size in bytes to a human-readable string"""
|
3 |
+
if size_bytes == 0:
|
4 |
+
return "0 B"
|
5 |
+
size_names = ("B", "KB", "MB", "GB", "TB", "PB")
|
6 |
+
i = 0
|
7 |
+
while size_bytes >= 1024 and i < len(size_names) - 1:
|
8 |
+
size_bytes /= 1024
|
9 |
+
i += 1
|
10 |
+
return f"{size_bytes:.2f} {size_names[i]}"
|
vms/{services β ui/project/services}/__init__.py
RENAMED
@@ -1,6 +1,5 @@
|
|
1 |
from .captioning import CaptioningProgress, CaptioningService
|
2 |
from .importing import ImportingService
|
3 |
-
from .monitoring import MonitoringService
|
4 |
from .splitting import SplittingService
|
5 |
from .previewing import PreviewingService
|
6 |
from .training import TrainingService
|
@@ -9,7 +8,6 @@ __all__ = [
|
|
9 |
'CaptioningProgress',
|
10 |
'CaptioningService',
|
11 |
'ImportingService',
|
12 |
-
'MonitoringService',
|
13 |
'SplittingService',
|
14 |
'PreviewingService',
|
15 |
'TrainingService',
|
|
|
1 |
from .captioning import CaptioningProgress, CaptioningService
|
2 |
from .importing import ImportingService
|
|
|
3 |
from .splitting import SplittingService
|
4 |
from .previewing import PreviewingService
|
5 |
from .training import TrainingService
|
|
|
8 |
'CaptioningProgress',
|
9 |
'CaptioningService',
|
10 |
'ImportingService',
|
|
|
11 |
'SplittingService',
|
12 |
'PreviewingService',
|
13 |
'TrainingService',
|
vms/{services β ui/project/services}/captioning.py
RENAMED
@@ -17,8 +17,8 @@ from llava.mm_utils import tokenizer_image_token
|
|
17 |
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
18 |
from llava.conversation import conv_templates, SeparatorStyle
|
19 |
|
20 |
-
from
|
21 |
-
from
|
22 |
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
|
|
17 |
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
18 |
from llava.conversation import conv_templates, SeparatorStyle
|
19 |
|
20 |
+
from vms.config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX
|
21 |
+
from vms.utils import extract_scene_info, is_image_file, is_video_file, copy_files_to_training_dir, prepare_finetrainers_dataset
|
22 |
|
23 |
logger = logging.getLogger(__name__)
|
24 |
|
vms/{services β ui/project/services}/importing/__init__.py
RENAMED
File without changes
|
vms/{services β ui/project/services}/importing/file_upload.py
RENAMED
File without changes
|
vms/{services β ui/project/services}/importing/hub_dataset.py
RENAMED
@@ -113,7 +113,7 @@ class HubDatasetBrowser:
|
|
113 |
dataset_info = self.hf_api.dataset_info(dataset_id)
|
114 |
|
115 |
# Format the information for display
|
116 |
-
info_text = f"
|
117 |
|
118 |
# Add description if available (with safer access)
|
119 |
card_data = getattr(dataset_info, "card_data", None)
|
@@ -126,11 +126,11 @@ class HubDatasetBrowser:
|
|
126 |
info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"
|
127 |
|
128 |
# Add basic stats (with safer access)
|
129 |
-
downloads = getattr(dataset_info, 'downloads', None)
|
130 |
-
info_text += f"## Downloads: {downloads if downloads is not None else 'N/A'}\n"
|
131 |
|
132 |
-
last_modified = getattr(dataset_info, 'last_modified', None)
|
133 |
-
info_text += f"## Last modified: {last_modified if last_modified is not None else 'N/A'}\n"
|
134 |
|
135 |
# Group files by type
|
136 |
file_groups = {
|
|
|
113 |
dataset_info = self.hf_api.dataset_info(dataset_id)
|
114 |
|
115 |
# Format the information for display
|
116 |
+
info_text = f"### {dataset_info.id}\n\n"
|
117 |
|
118 |
# Add description if available (with safer access)
|
119 |
card_data = getattr(dataset_info, "card_data", None)
|
|
|
126 |
info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"
|
127 |
|
128 |
# Add basic stats (with safer access)
|
129 |
+
#downloads = getattr(dataset_info, 'downloads', None)
|
130 |
+
#info_text += f"## Downloads: {downloads if downloads is not None else 'N/A'}\n"
|
131 |
|
132 |
+
#last_modified = getattr(dataset_info, 'last_modified', None)
|
133 |
+
#info_text += f"## Last modified: {last_modified if last_modified is not None else 'N/A'}\n"
|
134 |
|
135 |
# Group files by type
|
136 |
file_groups = {
|
vms/{services β ui/project/services}/importing/import_service.py
RENAMED
@@ -10,11 +10,12 @@ import gradio as gr
|
|
10 |
|
11 |
from huggingface_hub import HfApi
|
12 |
|
13 |
-
from .file_upload import FileUploadHandler
|
14 |
-
from .youtube import YouTubeDownloader
|
15 |
-
from .hub_dataset import HubDatasetBrowser
|
16 |
from vms.config import HF_API_TOKEN
|
17 |
|
|
|
|
|
|
|
|
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
class ImportingService:
|
|
|
10 |
|
11 |
from huggingface_hub import HfApi
|
12 |
|
|
|
|
|
|
|
13 |
from vms.config import HF_API_TOKEN
|
14 |
|
15 |
+
from vms.ui.project.services.importing.file_upload import FileUploadHandler
|
16 |
+
from vms.ui.project.services.importing.youtube import YouTubeDownloader
|
17 |
+
from vms.ui.project.services.importing.hub_dataset import HubDatasetBrowser
|
18 |
+
|
19 |
logger = logging.getLogger(__name__)
|
20 |
|
21 |
class ImportingService:
|
vms/{services β ui/project/services}/importing/youtube.py
RENAMED
File without changes
|
vms/{services β ui/project/services}/previewing.py
RENAMED
File without changes
|
vms/{services β ui/project/services}/splitting.py
RENAMED
@@ -12,8 +12,8 @@ import gradio as gr
|
|
12 |
from scenedetect import detect, ContentDetector, SceneManager, open_video
|
13 |
from scenedetect.video_splitter import split_video_ffmpeg
|
14 |
|
15 |
-
from
|
16 |
-
from
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
|
|
12 |
from scenedetect import detect, ContentDetector, SceneManager, open_video
|
13 |
from scenedetect.video_splitter import split_video_ffmpeg
|
14 |
|
15 |
+
from vms.config import TRAINING_PATH, STORAGE_PATH, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
|
16 |
+
from vms.utils import remove_black_bars, extract_scene_info, is_video_file, is_image_file, add_prefix_to_caption
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
vms/{services β ui/project/services}/training.py
RENAMED
@@ -20,7 +20,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple
|
|
20 |
|
21 |
from huggingface_hub import upload_folder, create_repo
|
22 |
|
23 |
-
from
|
24 |
TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
|
25 |
STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
|
26 |
MODEL_TYPES, TRAINING_TYPES,
|
@@ -39,7 +39,7 @@ from ..config import (
|
|
39 |
DEFAULT_NB_TRAINING_STEPS,
|
40 |
DEFAULT_NB_LR_WARMUP_STEPS
|
41 |
)
|
42 |
-
from
|
43 |
get_available_gpu_count,
|
44 |
make_archive,
|
45 |
parse_training_log,
|
|
|
20 |
|
21 |
from huggingface_hub import upload_folder, create_repo
|
22 |
|
23 |
+
from vms.config import (
|
24 |
TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH,
|
25 |
STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN,
|
26 |
MODEL_TYPES, TRAINING_TYPES,
|
|
|
39 |
DEFAULT_NB_TRAINING_STEPS,
|
40 |
DEFAULT_NB_LR_WARMUP_STEPS
|
41 |
)
|
42 |
+
from vms.utils import (
|
43 |
get_available_gpu_count,
|
44 |
make_archive,
|
45 |
parse_training_log,
|
vms/{tabs β ui/project/tabs}/__init__.py
RENAMED
@@ -1,23 +1,19 @@
|
|
1 |
"""
|
2 |
-
Tab components for
|
3 |
"""
|
4 |
|
5 |
-
from .base_tab import BaseTab
|
6 |
from .import_tab import ImportTab
|
7 |
from .split_tab import SplitTab
|
8 |
from .caption_tab import CaptionTab
|
9 |
from .train_tab import TrainTab
|
10 |
-
from .monitor_tab import MonitorTab
|
11 |
from .preview_tab import PreviewTab
|
12 |
from .manage_tab import ManageTab
|
13 |
|
14 |
__all__ = [
|
15 |
-
'BaseTab',
|
16 |
'ImportTab',
|
17 |
'SplitTab',
|
18 |
'CaptionTab',
|
19 |
'TrainTab',
|
20 |
-
'MonitorTab',
|
21 |
'PreviewTab',
|
22 |
'ManageTab'
|
23 |
]
|
|
|
1 |
"""
|
2 |
+
Tab components for the "project" view
|
3 |
"""
|
4 |
|
|
|
5 |
from .import_tab import ImportTab
|
6 |
from .split_tab import SplitTab
|
7 |
from .caption_tab import CaptionTab
|
8 |
from .train_tab import TrainTab
|
|
|
9 |
from .preview_tab import PreviewTab
|
10 |
from .manage_tab import ManageTab
|
11 |
|
12 |
__all__ = [
|
|
|
13 |
'ImportTab',
|
14 |
'SplitTab',
|
15 |
'CaptionTab',
|
16 |
'TrainTab',
|
|
|
17 |
'PreviewTab',
|
18 |
'ManageTab'
|
19 |
]
|
vms/{tabs β ui/project/tabs}/caption_tab.py
RENAMED
@@ -8,10 +8,10 @@ import asyncio
|
|
8 |
import traceback
|
9 |
from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple
|
10 |
from pathlib import Path
|
|
|
11 |
|
12 |
-
from .
|
13 |
-
from
|
14 |
-
from ..utils import is_image_file, is_video_file, copy_files_to_training_dir
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
@@ -533,9 +533,7 @@ class CaptionTab(BaseTab):
|
|
533 |
Returns:
|
534 |
Dict with preview content for each preview component
|
535 |
"""
|
536 |
-
|
537 |
-
from ..config import TRAINING_VIDEOS_PATH
|
538 |
-
|
539 |
if not selected_text or "Caption:" in selected_text:
|
540 |
return {
|
541 |
"video": None,
|
|
|
8 |
import traceback
|
9 |
from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple
|
10 |
from pathlib import Path
|
11 |
+
import mimetypes
|
12 |
|
13 |
+
from vms.utils import BaseTab, is_image_file, is_video_file, copy_files_to_training_dir
|
14 |
+
from vms.config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, TRAINING_VIDEOS_PATH
|
|
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
|
|
533 |
Returns:
|
534 |
Dict with preview content for each preview component
|
535 |
"""
|
536 |
+
|
|
|
|
|
537 |
if not selected_text or "Caption:" in selected_text:
|
538 |
return {
|
539 |
"video": None,
|
vms/{tabs β ui/project/tabs}/import_tab/__init__.py
RENAMED
File without changes
|
vms/{tabs β ui/project/tabs}/import_tab/hub_tab.py
RENAMED
@@ -10,7 +10,7 @@ import threading
|
|
10 |
from pathlib import Path
|
11 |
from typing import Dict, Any, List, Optional, Tuple
|
12 |
|
13 |
-
from
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
@@ -26,22 +26,32 @@ class HubTab(BaseTab):
|
|
26 |
def create(self, parent=None) -> gr.Tab:
|
27 |
"""Create the Hub tab UI components"""
|
28 |
with gr.Tab(self.title, id=self.id) as tab:
|
|
|
29 |
with gr.Column():
|
30 |
with gr.Row():
|
31 |
-
gr.Markdown("## Import from
|
32 |
|
33 |
with gr.Row():
|
34 |
-
gr.
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
with gr.Row():
|
37 |
-
self.components["
|
38 |
-
|
39 |
-
|
|
|
40 |
)
|
41 |
|
42 |
-
with gr.Row():
|
43 |
-
self.components["dataset_search_btn"] = gr.Button("Search Datasets", variant="primary")
|
44 |
-
|
45 |
# Dataset browser results section
|
46 |
with gr.Row(visible=False) as dataset_results_row:
|
47 |
self.components["dataset_results_row"] = dataset_results_row
|
@@ -66,8 +76,6 @@ class HubTab(BaseTab):
|
|
66 |
with gr.Column(visible=False) as files_section:
|
67 |
self.components["files_section"] = files_section
|
68 |
|
69 |
-
gr.Markdown("## Files:")
|
70 |
-
|
71 |
# Video files row (appears if videos are present)
|
72 |
with gr.Row() as video_files_row:
|
73 |
self.components["video_files_row"] = video_files_row
|
|
|
10 |
from pathlib import Path
|
11 |
from typing import Dict, Any, List, Optional, Tuple
|
12 |
|
13 |
+
from vms.utils import BaseTab
|
14 |
|
15 |
logger = logging.getLogger(__name__)
|
16 |
|
|
|
26 |
def create(self, parent=None) -> gr.Tab:
|
27 |
"""Create the Hub tab UI components"""
|
28 |
with gr.Tab(self.title, id=self.id) as tab:
|
29 |
+
|
30 |
with gr.Column():
|
31 |
with gr.Row():
|
32 |
+
gr.Markdown("## Import a dataset from Hugging Face")
|
33 |
|
34 |
with gr.Row():
|
35 |
+
with gr.Column():
|
36 |
+
with gr.Row():
|
37 |
+
gr.Markdown("You can use any dataset containing video files (.mp4) with optional captions (same names but in .txt format)")
|
38 |
+
|
39 |
+
with gr.Row():
|
40 |
+
gr.Markdown("You can also use a dataset containing WebDataset shards (.tar files).")
|
41 |
+
|
42 |
+
with gr.Column():
|
43 |
+
self.components["dataset_search"] = gr.Textbox(
|
44 |
+
label="Search Hugging Face Datasets (MP4, WebDataset)",
|
45 |
+
placeholder="video datasets eg. cakeify, disney, rickroll.."
|
46 |
+
)
|
47 |
+
|
48 |
with gr.Row():
|
49 |
+
self.components["dataset_search_btn"] = gr.Button(
|
50 |
+
"Search Datasets",
|
51 |
+
variant="primary",
|
52 |
+
#size="md"
|
53 |
)
|
54 |
|
|
|
|
|
|
|
55 |
# Dataset browser results section
|
56 |
with gr.Row(visible=False) as dataset_results_row:
|
57 |
self.components["dataset_results_row"] = dataset_results_row
|
|
|
76 |
with gr.Column(visible=False) as files_section:
|
77 |
self.components["files_section"] = files_section
|
78 |
|
|
|
|
|
79 |
# Video files row (appears if videos are present)
|
80 |
with gr.Row() as video_files_row:
|
81 |
self.components["video_files_row"] = video_files_row
|
vms/{tabs β ui/project/tabs}/import_tab/import_tab.py
RENAMED
@@ -9,10 +9,10 @@ import threading
|
|
9 |
from pathlib import Path
|
10 |
from typing import Dict, Any, List, Optional, Tuple
|
11 |
|
12 |
-
from
|
13 |
-
from .upload_tab import UploadTab
|
14 |
-
from .youtube_tab import YouTubeTab
|
15 |
-
from .hub_tab import HubTab
|
16 |
|
17 |
from vms.config import (
|
18 |
VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
|
@@ -28,7 +28,8 @@ class ImportTab(BaseTab):
|
|
28 |
super().__init__(app_state)
|
29 |
self.id = "import_tab"
|
30 |
self.title = "1οΈβ£ Import"
|
31 |
-
|
|
|
32 |
self.upload_tab = UploadTab(app_state)
|
33 |
self.youtube_tab = YouTubeTab(app_state)
|
34 |
self.hub_tab = HubTab(app_state)
|
@@ -53,7 +54,11 @@ class ImportTab(BaseTab):
|
|
53 |
visible=True,
|
54 |
)
|
55 |
|
56 |
-
# Create
|
|
|
|
|
|
|
|
|
57 |
with gr.Tabs() as import_tabs:
|
58 |
# Create each sub-tab
|
59 |
self.upload_tab.create(import_tabs)
|
@@ -64,19 +69,20 @@ class ImportTab(BaseTab):
|
|
64 |
self.components["upload_tab"] = self.upload_tab
|
65 |
self.components["youtube_tab"] = self.youtube_tab
|
66 |
self.components["hub_tab"] = self.hub_tab
|
67 |
-
|
68 |
-
with gr.Row():
|
69 |
-
self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
|
70 |
|
71 |
return tab
|
72 |
|
73 |
def connect_events(self) -> None:
|
74 |
"""Connect event handlers to UI components"""
|
75 |
-
# Set shared components from parent tab to sub-tabs
|
76 |
for subtab in [self.upload_tab, self.youtube_tab, self.hub_tab]:
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
80 |
|
81 |
# Then connect events for each sub-tab
|
82 |
self.upload_tab.connect_events()
|
|
|
9 |
from pathlib import Path
|
10 |
from typing import Dict, Any, List, Optional, Tuple
|
11 |
|
12 |
+
from vms.utils import BaseTab
|
13 |
+
from vms.ui.project.tabs.import_tab.upload_tab import UploadTab
|
14 |
+
from vms.ui.project.tabs.import_tab.youtube_tab import YouTubeTab
|
15 |
+
from vms.ui.project.tabs.import_tab.hub_tab import HubTab
|
16 |
|
17 |
from vms.config import (
|
18 |
VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
|
|
|
28 |
super().__init__(app_state)
|
29 |
self.id = "import_tab"
|
30 |
self.title = "1οΈβ£ Import"
|
31 |
+
|
32 |
+
# Initialize sub-tabs - these should be created first
|
33 |
self.upload_tab = UploadTab(app_state)
|
34 |
self.youtube_tab = YouTubeTab(app_state)
|
35 |
self.hub_tab = HubTab(app_state)
|
|
|
54 |
visible=True,
|
55 |
)
|
56 |
|
57 |
+
# Create the import status textbox before creating the sub-tabs
|
58 |
+
with gr.Row():
|
59 |
+
self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
|
60 |
+
|
61 |
+
# Now create tabs for different import methods
|
62 |
with gr.Tabs() as import_tabs:
|
63 |
# Create each sub-tab
|
64 |
self.upload_tab.create(import_tabs)
|
|
|
69 |
self.components["upload_tab"] = self.upload_tab
|
70 |
self.components["youtube_tab"] = self.youtube_tab
|
71 |
self.components["hub_tab"] = self.hub_tab
|
|
|
|
|
|
|
72 |
|
73 |
return tab
|
74 |
|
75 |
def connect_events(self) -> None:
|
76 |
"""Connect event handlers to UI components"""
|
77 |
+
# Set shared components from parent tab to sub-tabs before connecting events
|
78 |
for subtab in [self.upload_tab, self.youtube_tab, self.hub_tab]:
|
79 |
+
# Ensure these components exist in the parent before sharing them
|
80 |
+
if "import_status" in self.components:
|
81 |
+
subtab.components["import_status"] = self.components["import_status"]
|
82 |
+
if "enable_automatic_video_split" in self.components:
|
83 |
+
subtab.components["enable_automatic_video_split"] = self.components["enable_automatic_video_split"]
|
84 |
+
if "enable_automatic_content_captioning" in self.components:
|
85 |
+
subtab.components["enable_automatic_content_captioning"] = self.components["enable_automatic_content_captioning"]
|
86 |
|
87 |
# Then connect events for each sub-tab
|
88 |
self.upload_tab.connect_events()
|
vms/{tabs β ui/project/tabs}/import_tab/upload_tab.py
RENAMED
@@ -8,7 +8,7 @@ import logging
|
|
8 |
from pathlib import Path
|
9 |
from typing import Dict, Any, Optional
|
10 |
|
11 |
-
from
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
@@ -19,6 +19,12 @@ class UploadTab(BaseTab):
|
|
19 |
super().__init__(app_state)
|
20 |
self.id = "upload_tab"
|
21 |
self.title = "Manual Upload"
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def create(self, parent=None) -> gr.Tab:
|
24 |
"""Create the Upload tab UI components"""
|
@@ -51,24 +57,50 @@ class UploadTab(BaseTab):
|
|
51 |
|
52 |
def connect_events(self) -> None:
|
53 |
"""Connect event handlers to UI components"""
|
|
|
|
|
|
|
|
|
|
|
54 |
# File upload event
|
55 |
-
self.components["files"].upload(
|
56 |
fn=lambda x: self.app.importing.process_uploaded_files(x),
|
57 |
inputs=[self.components["files"]],
|
58 |
-
outputs=[self.components["import_status"]]
|
59 |
-
)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
self.app.
|
69 |
-
self.app.tabs["split_tab"].components["
|
70 |
-
self.app.tabs["split_tab"].components["
|
71 |
-
self.app.tabs["
|
72 |
-
self.app.tabs["
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from pathlib import Path
|
9 |
from typing import Dict, Any, Optional
|
10 |
|
11 |
+
from vms.utils import BaseTab
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
19 |
super().__init__(app_state)
|
20 |
self.id = "upload_tab"
|
21 |
self.title = "Manual Upload"
|
22 |
+
# Initialize the components dictionary with None values for expected shared components
|
23 |
+
if "components" not in self.__dict__:
|
24 |
+
self.components = {}
|
25 |
+
self.components["import_status"] = None
|
26 |
+
self.components["enable_automatic_video_split"] = None
|
27 |
+
self.components["enable_automatic_content_captioning"] = None
|
28 |
|
29 |
def create(self, parent=None) -> gr.Tab:
|
30 |
"""Create the Upload tab UI components"""
|
|
|
57 |
|
58 |
def connect_events(self) -> None:
|
59 |
"""Connect event handlers to UI components"""
|
60 |
+
# Check if required shared components exist before connecting events
|
61 |
+
if not self.components.get("import_status"):
|
62 |
+
logger.warning("import_status component is not set in UploadTab")
|
63 |
+
return
|
64 |
+
|
65 |
# File upload event
|
66 |
+
upload_event = self.components["files"].upload(
|
67 |
fn=lambda x: self.app.importing.process_uploaded_files(x),
|
68 |
inputs=[self.components["files"]],
|
69 |
+
outputs=[self.components["import_status"]]
|
70 |
+
)
|
71 |
+
|
72 |
+
# Only add success handler if all required components exist
|
73 |
+
if hasattr(self.app.tabs, "import_tab") and hasattr(self.app.tabs, "split_tab") and \
|
74 |
+
hasattr(self.app.tabs, "caption_tab") and hasattr(self.app.tabs, "train_tab"):
|
75 |
+
|
76 |
+
# Get required components for success handler
|
77 |
+
try:
|
78 |
+
# If the components are missing, this will raise an AttributeError
|
79 |
+
tabs_component = self.app.tabs_component
|
80 |
+
video_list = self.app.tabs["split_tab"].components["video_list"]
|
81 |
+
detect_status = self.app.tabs["split_tab"].components["detect_status"]
|
82 |
+
split_title = self.app.tabs["split_tab"].components["split_title"]
|
83 |
+
caption_title = self.app.tabs["caption_tab"].components["caption_title"]
|
84 |
+
train_title = self.app.tabs["train_tab"].components["train_title"]
|
85 |
+
custom_prompt_prefix = self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
86 |
+
|
87 |
+
# Add success handler
|
88 |
+
upload_event.success(
|
89 |
+
fn=self.app.tabs["import_tab"].update_titles_after_import,
|
90 |
+
inputs=[
|
91 |
+
self.components["enable_automatic_video_split"],
|
92 |
+
self.components["enable_automatic_content_captioning"],
|
93 |
+
custom_prompt_prefix
|
94 |
+
],
|
95 |
+
outputs=[
|
96 |
+
tabs_component,
|
97 |
+
video_list,
|
98 |
+
detect_status,
|
99 |
+
split_title,
|
100 |
+
caption_title,
|
101 |
+
train_title
|
102 |
+
]
|
103 |
+
)
|
104 |
+
except (AttributeError, KeyError) as e:
|
105 |
+
logger.error(f"Error connecting event handlers in UploadTab: {str(e)}")
|
106 |
+
# Continue without the success handler
|
vms/{tabs β ui/project/tabs}/import_tab/youtube_tab.py
RENAMED
@@ -8,7 +8,7 @@ import logging
|
|
8 |
from pathlib import Path
|
9 |
from typing import Dict, Any, Optional
|
10 |
|
11 |
-
from
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
@@ -19,6 +19,12 @@ class YouTubeTab(BaseTab):
|
|
19 |
super().__init__(app_state)
|
20 |
self.id = "youtube_tab"
|
21 |
self.title = "Download from YouTube"
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def create(self, parent=None) -> gr.Tab:
|
24 |
"""Create the YouTube tab UI components"""
|
@@ -47,21 +53,51 @@ class YouTubeTab(BaseTab):
|
|
47 |
|
48 |
def connect_events(self) -> None:
|
49 |
"""Connect event handlers to UI components"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# YouTube download event
|
51 |
-
self.components["youtube_download_btn"].click(
|
52 |
fn=self.app.importing.download_youtube_video,
|
53 |
inputs=[self.components["youtube_url"]],
|
54 |
-
outputs=[self.components["import_status"]]
|
55 |
-
)
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from pathlib import Path
|
9 |
from typing import Dict, Any, Optional
|
10 |
|
11 |
+
from vms.utils import BaseTab
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
|
|
19 |
super().__init__(app_state)
|
20 |
self.id = "youtube_tab"
|
21 |
self.title = "Download from YouTube"
|
22 |
+
# Initialize components that will be shared from parent
|
23 |
+
if "components" not in self.__dict__:
|
24 |
+
self.components = {}
|
25 |
+
self.components["import_status"] = None
|
26 |
+
self.components["enable_automatic_video_split"] = None
|
27 |
+
self.components["enable_automatic_content_captioning"] = None
|
28 |
|
29 |
def create(self, parent=None) -> gr.Tab:
|
30 |
"""Create the YouTube tab UI components"""
|
|
|
53 |
|
54 |
def connect_events(self) -> None:
|
55 |
"""Connect event handlers to UI components"""
|
56 |
+
# Check if required shared components exist before connecting events
|
57 |
+
if not self.components.get("import_status"):
|
58 |
+
logger.warning("import_status component is not set in YouTubeTab")
|
59 |
+
return
|
60 |
+
|
61 |
+
if not self.components.get("enable_automatic_video_split"):
|
62 |
+
logger.warning("enable_automatic_video_split component is not set in YouTubeTab")
|
63 |
+
return
|
64 |
+
|
65 |
+
if not self.components.get("enable_automatic_content_captioning"):
|
66 |
+
logger.warning("enable_automatic_content_captioning component is not set in YouTubeTab")
|
67 |
+
return
|
68 |
+
|
69 |
+
# Only try to access custom_prompt_prefix if the caption_tab exists
|
70 |
+
custom_prompt_prefix = None
|
71 |
+
try:
|
72 |
+
if hasattr(self.app.tabs, "caption_tab") and "custom_prompt_prefix" in self.app.tabs["caption_tab"].components:
|
73 |
+
custom_prompt_prefix = self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
74 |
+
except (AttributeError, KeyError):
|
75 |
+
logger.warning("Could not access custom_prompt_prefix component")
|
76 |
+
|
77 |
# YouTube download event
|
78 |
+
download_event = self.components["youtube_download_btn"].click(
|
79 |
fn=self.app.importing.download_youtube_video,
|
80 |
inputs=[self.components["youtube_url"]],
|
81 |
+
outputs=[self.components["import_status"]]
|
82 |
+
)
|
83 |
+
|
84 |
+
# Add success handler if all components exist
|
85 |
+
if hasattr(self.app.tabs, "import_tab") and custom_prompt_prefix is not None:
|
86 |
+
try:
|
87 |
+
# Add the success handler
|
88 |
+
download_event.success(
|
89 |
+
fn=self.app.tabs["import_tab"].on_import_success,
|
90 |
+
inputs=[
|
91 |
+
self.components["enable_automatic_video_split"],
|
92 |
+
self.components["enable_automatic_content_captioning"],
|
93 |
+
custom_prompt_prefix
|
94 |
+
],
|
95 |
+
outputs=[
|
96 |
+
self.app.tabs_component,
|
97 |
+
self.app.tabs["split_tab"].components["video_list"],
|
98 |
+
self.app.tabs["split_tab"].components["detect_status"]
|
99 |
+
]
|
100 |
+
)
|
101 |
+
except (AttributeError, KeyError) as e:
|
102 |
+
logger.error(f"Error connecting success handler in YouTubeTab: {str(e)}")
|
103 |
+
# Continue without the success handler
|
vms/{tabs β ui/project/tabs}/manage_tab.py
RENAMED
@@ -8,12 +8,11 @@ import shutil
|
|
8 |
from pathlib import Path
|
9 |
from typing import Dict, Any, List, Optional
|
10 |
|
11 |
-
from .
|
12 |
-
from
|
13 |
HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH,
|
14 |
TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, LOG_FILE_PATH
|
15 |
)
|
16 |
-
from ..utils import validate_model_repo
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
@@ -23,59 +22,64 @@ class ManageTab(BaseTab):
|
|
23 |
def __init__(self, app_state):
|
24 |
super().__init__(app_state)
|
25 |
self.id = "manage_tab"
|
26 |
-
self.title = "
|
27 |
|
28 |
def create(self, parent=None) -> gr.TabItem:
|
29 |
"""Create the Manage tab UI components"""
|
30 |
with gr.TabItem(self.title, id=self.id) as tab:
|
31 |
-
with gr.
|
32 |
-
with gr.
|
33 |
-
|
34 |
-
|
35 |
-
gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
size="lg"
|
67 |
-
)
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
return tab
|
81 |
|
|
|
8 |
from pathlib import Path
|
9 |
from typing import Dict, Any, List, Optional
|
10 |
|
11 |
+
from vms.utils import BaseTab, validate_model_repo
|
12 |
+
from vms.config import (
|
13 |
HF_API_TOKEN, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH,
|
14 |
TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, LOG_FILE_PATH
|
15 |
)
|
|
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
22 |
def __init__(self, app_state):
|
23 |
super().__init__(app_state)
|
24 |
self.id = "manage_tab"
|
25 |
+
self.title = "6οΈβ£ Storage"
|
26 |
|
27 |
def create(self, parent=None) -> gr.TabItem:
|
28 |
"""Create the Manage tab UI components"""
|
29 |
with gr.TabItem(self.title, id=self.id) as tab:
|
30 |
+
with gr.Row():
|
31 |
+
with gr.Column():
|
32 |
+
gr.Markdown("## Download your model")
|
33 |
+
gr.Markdown("There is currently a bug, you might have to click multiple times to trigger a download.")
|
|
|
34 |
|
35 |
+
with gr.Row():
|
36 |
+
self.components["download_dataset_btn"] = gr.DownloadButton(
|
37 |
+
"Download training dataset",
|
38 |
+
variant="secondary",
|
39 |
+
size="lg"
|
40 |
+
)
|
41 |
+
self.components["download_model_btn"] = gr.DownloadButton(
|
42 |
+
"Download model weights",
|
43 |
+
variant="secondary",
|
44 |
+
size="lg"
|
45 |
+
)
|
46 |
+
with gr.Row():
|
47 |
+
with gr.Column():
|
48 |
+
gr.Markdown("## Publish your model")
|
49 |
+
gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
|
50 |
|
51 |
+
with gr.Row():
|
52 |
+
with gr.Column():
|
53 |
+
self.components["repo_id"] = gr.Textbox(
|
54 |
+
label="HuggingFace Model Repository",
|
55 |
+
placeholder="username/model-name",
|
56 |
+
info="The repository will be created if it doesn't exist"
|
57 |
+
)
|
58 |
+
self.components["make_public"] = gr.Checkbox(
|
59 |
+
label="Check this to make your model public (ie. visible and downloadable by anyone)",
|
60 |
+
info="You model is private by default"
|
61 |
+
)
|
62 |
+
self.components["push_model_btn"] = gr.Button(
|
63 |
+
"Push my model"
|
64 |
+
)
|
|
|
|
|
65 |
|
66 |
+
with gr.Row():
|
67 |
+
with gr.Column():
|
68 |
+
gr.Markdown("## Delete your model")
|
69 |
+
gr.Markdown("If something went wrong, you can trigger a full reset (model shutdown + data destruction).")
|
70 |
+
gr.Markdown("Make sure you have made a backup first.")
|
71 |
+
gr.Markdown("If you are deleting because of a bug, remember you can use the Developer Mode on HF to inspect the working directory (in /data or .data)")
|
72 |
+
|
73 |
+
with gr.Row():
|
74 |
+
self.components["global_stop_btn"] = gr.Button(
|
75 |
+
"Stop everything and delete my data",
|
76 |
+
variant="stop"
|
77 |
+
)
|
78 |
+
self.components["global_status"] = gr.Textbox(
|
79 |
+
label="Global Status",
|
80 |
+
interactive=False,
|
81 |
+
visible=False
|
82 |
+
)
|
83 |
|
84 |
return tab
|
85 |
|
vms/{tabs β ui/project/tabs}/preview_tab.py
RENAMED
@@ -8,8 +8,8 @@ from pathlib import Path
|
|
8 |
from typing import Dict, Any, List, Optional, Tuple
|
9 |
import time
|
10 |
|
11 |
-
from .
|
12 |
-
from
|
13 |
MODEL_TYPES, DEFAULT_PROMPT_PREFIX
|
14 |
)
|
15 |
|
@@ -21,13 +21,13 @@ class PreviewTab(BaseTab):
|
|
21 |
def __init__(self, app_state):
|
22 |
super().__init__(app_state)
|
23 |
self.id = "preview_tab"
|
24 |
-
self.title = "
|
25 |
|
26 |
def create(self, parent=None) -> gr.TabItem:
|
27 |
"""Create the Preview tab UI components"""
|
28 |
with gr.TabItem(self.title, id=self.id) as tab:
|
29 |
with gr.Row():
|
30 |
-
gr.Markdown("##
|
31 |
|
32 |
with gr.Row():
|
33 |
with gr.Column(scale=2):
|
|
|
8 |
from typing import Dict, Any, List, Optional, Tuple
|
9 |
import time
|
10 |
|
11 |
+
from vms.utils import BaseTab
|
12 |
+
from vms.config import (
|
13 |
MODEL_TYPES, DEFAULT_PROMPT_PREFIX
|
14 |
)
|
15 |
|
|
|
21 |
def __init__(self, app_state):
|
22 |
super().__init__(app_state)
|
23 |
self.id = "preview_tab"
|
24 |
+
self.title = "5οΈβ£ Preview"
|
25 |
|
26 |
def create(self, parent=None) -> gr.TabItem:
|
27 |
"""Create the Preview tab UI components"""
|
28 |
with gr.TabItem(self.title, id=self.id) as tab:
|
29 |
with gr.Row():
|
30 |
+
gr.Markdown("## Preview your model")
|
31 |
|
32 |
with gr.Row():
|
33 |
with gr.Column(scale=2):
|
vms/{tabs β ui/project/tabs}/split_tab.py
RENAMED
@@ -6,7 +6,7 @@ import gradio as gr
|
|
6 |
import logging
|
7 |
from typing import Dict, Any, List, Optional
|
8 |
|
9 |
-
from .
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
|
|
6 |
import logging
|
7 |
from typing import Dict, Any, List, Optional
|
8 |
|
9 |
+
from vms.utils import BaseTab
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
vms/{tabs β ui/project/tabs}/train_tab.py
RENAMED
@@ -8,8 +8,8 @@ import os
|
|
8 |
from typing import Dict, Any, List, Optional, Tuple
|
9 |
from pathlib import Path
|
10 |
|
11 |
-
from .
|
12 |
-
from
|
13 |
TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
|
14 |
DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
15 |
DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
|
|
|
8 |
from typing import Dict, Any, List, Optional, Tuple
|
9 |
from pathlib import Path
|
10 |
|
11 |
+
from vms.utils import BaseTab
|
12 |
+
from vms.config import (
|
13 |
TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES,
|
14 |
DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS,
|
15 |
DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P,
|
vms/utils/__init__.py
CHANGED
@@ -9,6 +9,7 @@ from .finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_trai
|
|
9 |
from . import webdataset_handler
|
10 |
|
11 |
from .gpu_detector import get_available_gpu_count, get_gpu_info, get_recommended_precomputation_items
|
|
|
12 |
|
13 |
__all__ = [
|
14 |
'validate_model_repo',
|
@@ -39,5 +40,7 @@ __all__ = [
|
|
39 |
|
40 |
'get_available_gpu_count',
|
41 |
'get_gpu_info',
|
42 |
-
'get_recommended_precomputation_items'
|
|
|
|
|
43 |
]
|
|
|
9 |
from . import webdataset_handler
|
10 |
|
11 |
from .gpu_detector import get_available_gpu_count, get_gpu_info, get_recommended_precomputation_items
|
12 |
+
from .base_tab import BaseTab
|
13 |
|
14 |
__all__ = [
|
15 |
'validate_model_repo',
|
|
|
40 |
|
41 |
'get_available_gpu_count',
|
42 |
'get_gpu_info',
|
43 |
+
'get_recommended_precomputation_items',
|
44 |
+
|
45 |
+
'BaseTab',
|
46 |
]
|
vms/{tabs β utils}/base_tab.py
RENAMED
@@ -15,7 +15,7 @@ class BaseTab:
|
|
15 |
"""Initialize the tab with app state reference
|
16 |
|
17 |
Args:
|
18 |
-
app_state: Reference to main
|
19 |
"""
|
20 |
self.app = app_state
|
21 |
self.components = {}
|
|
|
15 |
"""Initialize the tab with app state reference
|
16 |
|
17 |
Args:
|
18 |
+
app_state: Reference to main AppUI instance
|
19 |
"""
|
20 |
self.app = app_state
|
21 |
self.components = {}
|