Spaces:
Running
Running
Commit
·
ab45a2c
1
Parent(s):
05b8a89
fixed ui persistence issues
Browse files- vms/ui/app_ui.py +32 -24
- vms/ui/project/tabs/__init__.py +0 -2
- vms/ui/project/tabs/caption_tab.py +1 -1
- vms/ui/project/tabs/import_tab/hub_tab.py +30 -0
- vms/ui/project/tabs/import_tab/import_tab.py +65 -29
- vms/ui/project/tabs/import_tab/upload_tab.py +8 -9
- vms/ui/project/tabs/import_tab/youtube_tab.py +9 -3
- vms/ui/project/tabs/manage_tab.py +1 -5
- vms/ui/project/tabs/preview_tab.py +22 -17
- vms/ui/project/tabs/split_tab.py +0 -81
- vms/ui/project/tabs/train_tab.py +86 -24
vms/ui/app_ui.py
CHANGED
@@ -32,7 +32,7 @@ from vms.ui.project.services import (
|
|
32 |
TrainingService, CaptioningService, SplittingService, ImportingService, PreviewingService
|
33 |
)
|
34 |
from vms.ui.project.tabs import (
|
35 |
-
ImportTab,
|
36 |
)
|
37 |
|
38 |
from vms.ui.monitoring.services import (
|
@@ -164,7 +164,6 @@ class AppUI:
|
|
164 |
|
165 |
# Initialize project tab objects
|
166 |
self.project_tabs["import_tab"] = ImportTab(self)
|
167 |
-
self.project_tabs["split_tab"] = SplitTab(self)
|
168 |
self.project_tabs["caption_tab"] = CaptionTab(self)
|
169 |
self.project_tabs["train_tab"] = TrainTab(self)
|
170 |
self.project_tabs["preview_tab"] = PreviewTab(self)
|
@@ -213,7 +212,6 @@ class AppUI:
|
|
213 |
app.load(
|
214 |
fn=self.initialize_app_state,
|
215 |
outputs=[
|
216 |
-
self.project_tabs["split_tab"].components["video_list"],
|
217 |
self.project_tabs["caption_tab"].components["training_dataset"],
|
218 |
self.project_tabs["train_tab"].components["start_btn"],
|
219 |
self.project_tabs["train_tab"].components["stop_btn"],
|
@@ -273,7 +271,6 @@ class AppUI:
|
|
273 |
dataset_timer.tick(
|
274 |
fn=self.refresh_dataset,
|
275 |
outputs=[
|
276 |
-
self.project_tabs["split_tab"].components["video_list"],
|
277 |
self.project_tabs["caption_tab"].components["training_dataset"]
|
278 |
]
|
279 |
)
|
@@ -283,7 +280,6 @@ class AppUI:
|
|
283 |
titles_timer.tick(
|
284 |
fn=self.update_titles,
|
285 |
outputs=[
|
286 |
-
self.project_tabs["split_tab"].components["split_title"],
|
287 |
self.project_tabs["caption_tab"].components["caption_title"],
|
288 |
self.project_tabs["train_tab"].components["train_title"]
|
289 |
]
|
@@ -292,7 +288,6 @@ class AppUI:
|
|
292 |
def initialize_app_state(self):
|
293 |
"""Initialize all app state in one function to ensure correct output count"""
|
294 |
# Get dataset info
|
295 |
-
video_list = self.project_tabs["split_tab"].list_unprocessed_videos()
|
296 |
training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()
|
297 |
|
298 |
# Get button states based on recovery status
|
@@ -381,17 +376,40 @@ class AppUI:
|
|
381 |
|
382 |
# Get model_version value
|
383 |
model_version_val = ""
|
|
|
384 |
# First get the internal model type for the currently selected model
|
385 |
model_internal_type = MODEL_TYPES.get(model_type_val)
|
|
|
|
|
386 |
if model_internal_type and model_internal_type in MODEL_VERSIONS:
|
387 |
-
#
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
model_version_val = ui_state["model_version"]
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
395 |
|
396 |
# Ensure training_type is a valid display name
|
397 |
training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
|
@@ -444,7 +462,6 @@ class AppUI:
|
|
444 |
|
445 |
# Return all values in the exact order expected by outputs
|
446 |
return (
|
447 |
-
video_list,
|
448 |
training_dataset,
|
449 |
start_btn,
|
450 |
stop_btn,
|
@@ -464,7 +481,7 @@ class AppUI:
|
|
464 |
precomputation_items_val,
|
465 |
lr_warmup_steps_val
|
466 |
)
|
467 |
-
|
468 |
def initialize_ui_from_state(self):
|
469 |
"""Initialize UI components from saved state"""
|
470 |
ui_state = self.load_ui_values()
|
@@ -558,12 +575,6 @@ class AppUI:
|
|
558 |
Returns:
|
559 |
Dict of Gradio updates
|
560 |
"""
|
561 |
-
# Count files for splitting
|
562 |
-
split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
|
563 |
-
split_title = format_media_title(
|
564 |
-
"split", split_videos, 0, split_size
|
565 |
-
)
|
566 |
-
|
567 |
# Count files for captioning
|
568 |
caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
|
569 |
caption_title = format_media_title(
|
@@ -577,17 +588,14 @@ class AppUI:
|
|
577 |
)
|
578 |
|
579 |
return (
|
580 |
-
gr.Markdown(value=split_title),
|
581 |
gr.Markdown(value=caption_title),
|
582 |
gr.Markdown(value=f"{train_title} available for training")
|
583 |
)
|
584 |
|
585 |
def refresh_dataset(self):
|
586 |
"""Refresh all dynamic lists and training state"""
|
587 |
-
video_list = self.project_tabs["split_tab"].list_unprocessed_videos()
|
588 |
training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()
|
589 |
|
590 |
return (
|
591 |
-
video_list,
|
592 |
training_dataset
|
593 |
)
|
|
|
32 |
TrainingService, CaptioningService, SplittingService, ImportingService, PreviewingService
|
33 |
)
|
34 |
from vms.ui.project.tabs import (
|
35 |
+
ImportTab, CaptionTab, TrainTab, PreviewTab, ManageTab
|
36 |
)
|
37 |
|
38 |
from vms.ui.monitoring.services import (
|
|
|
164 |
|
165 |
# Initialize project tab objects
|
166 |
self.project_tabs["import_tab"] = ImportTab(self)
|
|
|
167 |
self.project_tabs["caption_tab"] = CaptionTab(self)
|
168 |
self.project_tabs["train_tab"] = TrainTab(self)
|
169 |
self.project_tabs["preview_tab"] = PreviewTab(self)
|
|
|
212 |
app.load(
|
213 |
fn=self.initialize_app_state,
|
214 |
outputs=[
|
|
|
215 |
self.project_tabs["caption_tab"].components["training_dataset"],
|
216 |
self.project_tabs["train_tab"].components["start_btn"],
|
217 |
self.project_tabs["train_tab"].components["stop_btn"],
|
|
|
271 |
dataset_timer.tick(
|
272 |
fn=self.refresh_dataset,
|
273 |
outputs=[
|
|
|
274 |
self.project_tabs["caption_tab"].components["training_dataset"]
|
275 |
]
|
276 |
)
|
|
|
280 |
titles_timer.tick(
|
281 |
fn=self.update_titles,
|
282 |
outputs=[
|
|
|
283 |
self.project_tabs["caption_tab"].components["caption_title"],
|
284 |
self.project_tabs["train_tab"].components["train_title"]
|
285 |
]
|
|
|
288 |
def initialize_app_state(self):
|
289 |
"""Initialize all app state in one function to ensure correct output count"""
|
290 |
# Get dataset info
|
|
|
291 |
training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()
|
292 |
|
293 |
# Get button states based on recovery status
|
|
|
376 |
|
377 |
# Get model_version value
|
378 |
model_version_val = ""
|
379 |
+
|
380 |
# First get the internal model type for the currently selected model
|
381 |
model_internal_type = MODEL_TYPES.get(model_type_val)
|
382 |
+
logger.info(f"Initializing model version for model_type: {model_type_val} (internal: {model_internal_type})")
|
383 |
+
|
384 |
if model_internal_type and model_internal_type in MODEL_VERSIONS:
|
385 |
+
# Get available versions for this model type as simple strings
|
386 |
+
available_model_versions = list(MODEL_VERSIONS.get(model_internal_type, {}).keys())
|
387 |
+
|
388 |
+
# Log for debugging
|
389 |
+
logger.info(f"Available versions: {available_model_versions}")
|
390 |
+
|
391 |
+
# Set model_version_val to saved value if valid, otherwise first available
|
392 |
+
if "model_version" in ui_state and ui_state["model_version"] in available_model_versions:
|
393 |
model_version_val = ui_state["model_version"]
|
394 |
+
logger.info(f"Using saved model version: {model_version_val}")
|
395 |
+
elif available_model_versions:
|
396 |
+
model_version_val = available_model_versions[0]
|
397 |
+
logger.info(f"Using first available model version: {model_version_val}")
|
398 |
+
|
399 |
+
# IMPORTANT: Update the dropdown choices directly in the UI component
|
400 |
+
# This is essential to avoid the error when loading the UI
|
401 |
+
try:
|
402 |
+
self.project_tabs["train_tab"].components["model_version"].choices = available_model_versions
|
403 |
+
logger.info(f"Updated model_version dropdown choices: {len(available_model_versions)} options")
|
404 |
+
except Exception as e:
|
405 |
+
logger.error(f"Error updating model_version dropdown: {str(e)}")
|
406 |
+
else:
|
407 |
+
logger.warning(f"No versions available for model type: {model_type_val}")
|
408 |
+
# Set empty choices to avoid errors
|
409 |
+
try:
|
410 |
+
self.project_tabs["train_tab"].components["model_version"].choices = []
|
411 |
+
except Exception as e:
|
412 |
+
logger.error(f"Error setting empty model_version choices: {str(e)}")
|
413 |
|
414 |
# Ensure training_type is a valid display name
|
415 |
training_type_val = ui_state.get("training_type", list(TRAINING_TYPES.keys())[0])
|
|
|
462 |
|
463 |
# Return all values in the exact order expected by outputs
|
464 |
return (
|
|
|
465 |
training_dataset,
|
466 |
start_btn,
|
467 |
stop_btn,
|
|
|
481 |
precomputation_items_val,
|
482 |
lr_warmup_steps_val
|
483 |
)
|
484 |
+
|
485 |
def initialize_ui_from_state(self):
|
486 |
"""Initialize UI components from saved state"""
|
487 |
ui_state = self.load_ui_values()
|
|
|
575 |
Returns:
|
576 |
Dict of Gradio updates
|
577 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
578 |
# Count files for captioning
|
579 |
caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
|
580 |
caption_title = format_media_title(
|
|
|
588 |
)
|
589 |
|
590 |
return (
|
|
|
591 |
gr.Markdown(value=caption_title),
|
592 |
gr.Markdown(value=f"{train_title} available for training")
|
593 |
)
|
594 |
|
595 |
def refresh_dataset(self):
|
596 |
"""Refresh all dynamic lists and training state"""
|
|
|
597 |
training_dataset = self.project_tabs["caption_tab"].list_training_files_to_caption()
|
598 |
|
599 |
return (
|
|
|
600 |
training_dataset
|
601 |
)
|
vms/ui/project/tabs/__init__.py
CHANGED
@@ -3,7 +3,6 @@ Tab components for the "project" view
|
|
3 |
"""
|
4 |
|
5 |
from .import_tab import ImportTab
|
6 |
-
from .split_tab import SplitTab
|
7 |
from .caption_tab import CaptionTab
|
8 |
from .train_tab import TrainTab
|
9 |
from .preview_tab import PreviewTab
|
@@ -11,7 +10,6 @@ from .manage_tab import ManageTab
|
|
11 |
|
12 |
__all__ = [
|
13 |
'ImportTab',
|
14 |
-
'SplitTab',
|
15 |
'CaptionTab',
|
16 |
'TrainTab',
|
17 |
'PreviewTab',
|
|
|
3 |
"""
|
4 |
|
5 |
from .import_tab import ImportTab
|
|
|
6 |
from .caption_tab import CaptionTab
|
7 |
from .train_tab import TrainTab
|
8 |
from .preview_tab import PreviewTab
|
|
|
10 |
|
11 |
__all__ = [
|
12 |
'ImportTab',
|
|
|
13 |
'CaptionTab',
|
14 |
'TrainTab',
|
15 |
'PreviewTab',
|
vms/ui/project/tabs/caption_tab.py
CHANGED
@@ -21,7 +21,7 @@ class CaptionTab(BaseTab):
|
|
21 |
def __init__(self, app_state):
|
22 |
super().__init__(app_state)
|
23 |
self.id = "caption_tab"
|
24 |
-
self.title = "
|
25 |
self._should_stop_captioning = False
|
26 |
|
27 |
def create(self, parent=None) -> gr.TabItem:
|
|
|
21 |
def __init__(self, app_state):
|
22 |
super().__init__(app_state)
|
23 |
self.id = "caption_tab"
|
24 |
+
self.title = "2️⃣ Caption"
|
25 |
self._should_stop_captioning = False
|
26 |
|
27 |
def create(self, parent=None) -> gr.TabItem:
|
vms/ui/project/tabs/import_tab/hub_tab.py
CHANGED
@@ -124,6 +124,14 @@ class HubTab(BaseTab):
|
|
124 |
]
|
125 |
)
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
# Download videos button
|
128 |
self.components["download_videos_btn"].click(
|
129 |
fn=self.set_file_type_and_return,
|
@@ -142,6 +150,17 @@ class HubTab(BaseTab):
|
|
142 |
self.components["download_webdataset_btn"],
|
143 |
self.components["download_in_progress"]
|
144 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
)
|
146 |
|
147 |
# Download WebDataset button
|
@@ -162,6 +181,17 @@ class HubTab(BaseTab):
|
|
162 |
self.components["download_webdataset_btn"],
|
163 |
self.components["download_in_progress"]
|
164 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
)
|
166 |
|
167 |
def set_file_type_and_return(self):
|
|
|
124 |
]
|
125 |
)
|
126 |
|
127 |
+
# Check if we have access to project_tabs_component
|
128 |
+
if hasattr(self.app, "project_tabs_component"):
|
129 |
+
tabs_component = self.app.project_tabs_component
|
130 |
+
else:
|
131 |
+
# Fallback to prevent errors
|
132 |
+
logger.warning("project_tabs_component not found in app, using None for tab switching")
|
133 |
+
tabs_component = None
|
134 |
+
|
135 |
# Download videos button
|
136 |
self.components["download_videos_btn"].click(
|
137 |
fn=self.set_file_type_and_return,
|
|
|
150 |
self.components["download_webdataset_btn"],
|
151 |
self.components["download_in_progress"]
|
152 |
]
|
153 |
+
).success(
|
154 |
+
fn=self.app.tabs["import_tab"].on_import_success,
|
155 |
+
inputs=[
|
156 |
+
self.components["enable_automatic_video_split"],
|
157 |
+
self.components["enable_automatic_content_captioning"],
|
158 |
+
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
159 |
+
],
|
160 |
+
outputs=[
|
161 |
+
tabs_component,
|
162 |
+
self.components["status_output"]
|
163 |
+
]
|
164 |
)
|
165 |
|
166 |
# Download WebDataset button
|
|
|
181 |
self.components["download_webdataset_btn"],
|
182 |
self.components["download_in_progress"]
|
183 |
]
|
184 |
+
).success(
|
185 |
+
fn=self.app.tabs["import_tab"].on_import_success,
|
186 |
+
inputs=[
|
187 |
+
self.components["enable_automatic_video_split"],
|
188 |
+
self.components["enable_automatic_content_captioning"],
|
189 |
+
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
190 |
+
],
|
191 |
+
outputs=[
|
192 |
+
tabs_component,
|
193 |
+
self.components["status_output"]
|
194 |
+
]
|
195 |
)
|
196 |
|
197 |
def set_file_type_and_return(self):
|
vms/ui/project/tabs/import_tab/import_tab.py
CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
|
|
6 |
import logging
|
7 |
import asyncio
|
8 |
import threading
|
|
|
9 |
from pathlib import Path
|
10 |
from typing import Dict, Any, List, Optional, Tuple
|
11 |
|
@@ -91,27 +92,35 @@ class ImportTab(BaseTab):
|
|
91 |
|
92 |
def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
93 |
"""Handle successful import of files"""
|
94 |
-
videos
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
# Start the scene detection in a separate thread
|
100 |
-
self._start_scene_detection_bg(enable_splitting)
|
101 |
-
msg = "Starting automatic scene detection..."
|
102 |
else:
|
103 |
-
#
|
104 |
-
self.
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
108 |
|
|
|
|
|
|
|
109 |
# Start auto-captioning if enabled
|
110 |
if enable_automatic_content_captioning:
|
111 |
self._start_captioning_bg(DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, prompt_prefix)
|
112 |
|
113 |
-
#
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
def _start_scene_detection_bg(self, enable_splitting):
|
117 |
"""Start scene detection in a background thread"""
|
@@ -120,7 +129,7 @@ class ImportTab(BaseTab):
|
|
120 |
asyncio.set_event_loop(loop)
|
121 |
try:
|
122 |
loop.run_until_complete(
|
123 |
-
self.app.
|
124 |
)
|
125 |
except Exception as e:
|
126 |
logger.error(f"Error in background scene detection: {str(e)}", exc_info=True)
|
@@ -131,21 +140,48 @@ class ImportTab(BaseTab):
|
|
131 |
thread.daemon = True
|
132 |
thread.start()
|
133 |
|
134 |
-
def
|
135 |
-
"""Start copying files in a background thread"""
|
136 |
def run_async_in_thread():
|
137 |
-
loop = asyncio.new_event_loop()
|
138 |
-
asyncio.set_event_loop(loop)
|
139 |
try:
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
-
loop.run_until_complete(copy_files())
|
145 |
except Exception as e:
|
146 |
-
logger.error(f"Error in background file copying: {str(e)}", exc_info=True)
|
147 |
-
finally:
|
148 |
-
loop.close()
|
149 |
|
150 |
thread = threading.Thread(target=run_async_in_thread)
|
151 |
thread.daemon = True
|
@@ -174,7 +210,7 @@ class ImportTab(BaseTab):
|
|
174 |
async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
175 |
"""Handle post-import updates including titles"""
|
176 |
# Call the non-async version since we need to return immediately for the UI
|
177 |
-
tabs,
|
178 |
enable_splitting, enable_automatic_content_captioning, prompt_prefix
|
179 |
)
|
180 |
|
@@ -182,4 +218,4 @@ class ImportTab(BaseTab):
|
|
182 |
titles = self.app.update_titles()
|
183 |
|
184 |
# Return all expected outputs
|
185 |
-
return tabs,
|
|
|
6 |
import logging
|
7 |
import asyncio
|
8 |
import threading
|
9 |
+
import shutil
|
10 |
from pathlib import Path
|
11 |
from typing import Dict, Any, List, Optional, Tuple
|
12 |
|
|
|
92 |
|
93 |
def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
94 |
"""Handle successful import of files"""
|
95 |
+
# If splitting is disabled, we need to directly move videos to staging
|
96 |
+
if not enable_splitting:
|
97 |
+
# Copy files without splitting
|
98 |
+
self._start_copy_to_staging_bg()
|
99 |
+
msg = "Copying videos to staging directory without splitting..."
|
|
|
|
|
|
|
100 |
else:
|
101 |
+
# Start scene detection if not already running and there are videos to process
|
102 |
+
if not self.app.splitting.is_processing():
|
103 |
+
# Start the scene detection in a separate thread
|
104 |
+
self._start_scene_detection_bg(enable_splitting)
|
105 |
+
msg = "Starting automatic scene detection..."
|
106 |
+
else:
|
107 |
+
msg = "Scene detection already running..."
|
108 |
|
109 |
+
# Copy files to training directory
|
110 |
+
self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
|
111 |
+
|
112 |
# Start auto-captioning if enabled
|
113 |
if enable_automatic_content_captioning:
|
114 |
self._start_captioning_bg(DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, prompt_prefix)
|
115 |
|
116 |
+
# Check if we have access to project_tabs_component for tab switching
|
117 |
+
if hasattr(self.app, "project_tabs_component") and self.app.project_tabs_component is not None:
|
118 |
+
# Now redirect to the caption tab instead of split tab
|
119 |
+
return gr.update(selected="caption_tab"), msg
|
120 |
+
else:
|
121 |
+
# If no tabs component is available, just return the message
|
122 |
+
logger.warning("Cannot switch tabs - project_tabs_component not available")
|
123 |
+
return None, msg
|
124 |
|
125 |
def _start_scene_detection_bg(self, enable_splitting):
|
126 |
"""Start scene detection in a background thread"""
|
|
|
129 |
asyncio.set_event_loop(loop)
|
130 |
try:
|
131 |
loop.run_until_complete(
|
132 |
+
self.app.splitting.start_processing(enable_splitting)
|
133 |
)
|
134 |
except Exception as e:
|
135 |
logger.error(f"Error in background scene detection: {str(e)}", exc_info=True)
|
|
|
140 |
thread.daemon = True
|
141 |
thread.start()
|
142 |
|
143 |
+
def _start_copy_to_staging_bg(self):
|
144 |
+
"""Start copying files directly to staging directory in a background thread"""
|
145 |
def run_async_in_thread():
|
|
|
|
|
146 |
try:
|
147 |
+
# Copy all videos from videos_to_split to staging without scene detection
|
148 |
+
for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
|
149 |
+
try:
|
150 |
+
# Ensure unique filename in staging directory
|
151 |
+
target_path = STAGING_PATH / video_file.name
|
152 |
+
counter = 1
|
153 |
+
|
154 |
+
while target_path.exists():
|
155 |
+
stem = video_file.stem
|
156 |
+
if "___" in stem:
|
157 |
+
base_stem = stem.split("___")[0]
|
158 |
+
else:
|
159 |
+
base_stem = stem
|
160 |
+
target_path = STAGING_PATH / f"{base_stem}___{counter}{video_file.suffix}"
|
161 |
+
counter += 1
|
162 |
+
|
163 |
+
# Copy the video file to staging
|
164 |
+
shutil.copy2(video_file, target_path)
|
165 |
+
logger.info(f"Copied video directly to staging: {video_file.name} -> {target_path.name}")
|
166 |
+
|
167 |
+
# Copy caption file if it exists
|
168 |
+
caption_path = video_file.with_suffix('.txt')
|
169 |
+
if caption_path.exists():
|
170 |
+
shutil.copy2(caption_path, target_path.with_suffix('.txt'))
|
171 |
+
logger.info(f"Copied caption for {video_file.name}")
|
172 |
+
|
173 |
+
# Remove original after successful copy
|
174 |
+
video_file.unlink()
|
175 |
+
if caption_path.exists():
|
176 |
+
caption_path.unlink()
|
177 |
+
|
178 |
+
gr.Info(f"Imported {video_file.name} directly to staging")
|
179 |
+
|
180 |
+
except Exception as e:
|
181 |
+
logger.error(f"Error copying {video_file.name} to staging: {str(e)}", exc_info=True)
|
182 |
|
|
|
183 |
except Exception as e:
|
184 |
+
logger.error(f"Error in background file copying to staging: {str(e)}", exc_info=True)
|
|
|
|
|
185 |
|
186 |
thread = threading.Thread(target=run_async_in_thread)
|
187 |
thread.daemon = True
|
|
|
210 |
async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
211 |
"""Handle post-import updates including titles"""
|
212 |
# Call the non-async version since we need to return immediately for the UI
|
213 |
+
tabs, status_msg = self.on_import_success(
|
214 |
enable_splitting, enable_automatic_content_captioning, prompt_prefix
|
215 |
)
|
216 |
|
|
|
218 |
titles = self.app.update_titles()
|
219 |
|
220 |
# Return all expected outputs
|
221 |
+
return tabs, status_msg, *titles
|
vms/ui/project/tabs/import_tab/upload_tab.py
CHANGED
@@ -70,16 +70,17 @@ class UploadTab(BaseTab):
|
|
70 |
)
|
71 |
|
72 |
# Only add success handler if all required components exist
|
73 |
-
if hasattr(self.app.tabs, "import_tab") and hasattr(self.app.tabs, "
|
74 |
-
hasattr(self.app.tabs, "caption_tab") and hasattr(self.app.tabs, "train_tab"):
|
75 |
|
76 |
# Get required components for success handler
|
77 |
try:
|
78 |
# If the components are missing, this will raise an AttributeError
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
83 |
caption_title = self.app.tabs["caption_tab"].components["caption_title"]
|
84 |
train_title = self.app.tabs["train_tab"].components["train_title"]
|
85 |
custom_prompt_prefix = self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
@@ -94,9 +95,7 @@ class UploadTab(BaseTab):
|
|
94 |
],
|
95 |
outputs=[
|
96 |
tabs_component,
|
97 |
-
|
98 |
-
detect_status,
|
99 |
-
split_title,
|
100 |
caption_title,
|
101 |
train_title
|
102 |
]
|
|
|
70 |
)
|
71 |
|
72 |
# Only add success handler if all required components exist
|
73 |
+
if hasattr(self.app.tabs, "import_tab") and hasattr(self.app.tabs, "caption_tab") and hasattr(self.app.tabs, "train_tab"):
|
|
|
74 |
|
75 |
# Get required components for success handler
|
76 |
try:
|
77 |
# If the components are missing, this will raise an AttributeError
|
78 |
+
if hasattr(self.app, "project_tabs_component"):
|
79 |
+
tabs_component = self.app.project_tabs_component
|
80 |
+
else:
|
81 |
+
logger.warning("project_tabs_component not found in app, using None for tab switching")
|
82 |
+
tabs_component = None
|
83 |
+
|
84 |
caption_title = self.app.tabs["caption_tab"].components["caption_title"]
|
85 |
train_title = self.app.tabs["train_tab"].components["train_title"]
|
86 |
custom_prompt_prefix = self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
|
|
95 |
],
|
96 |
outputs=[
|
97 |
tabs_component,
|
98 |
+
self.components["import_status"],
|
|
|
|
|
99 |
caption_title,
|
100 |
train_title
|
101 |
]
|
vms/ui/project/tabs/import_tab/youtube_tab.py
CHANGED
@@ -74,6 +74,13 @@ class YouTubeTab(BaseTab):
|
|
74 |
except (AttributeError, KeyError):
|
75 |
logger.warning("Could not access custom_prompt_prefix component")
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
# YouTube download event
|
78 |
download_event = self.components["youtube_download_btn"].click(
|
79 |
fn=self.app.importing.download_youtube_video,
|
@@ -93,9 +100,8 @@ class YouTubeTab(BaseTab):
|
|
93 |
custom_prompt_prefix
|
94 |
],
|
95 |
outputs=[
|
96 |
-
|
97 |
-
self.
|
98 |
-
self.app.tabs["split_tab"].components["detect_status"]
|
99 |
]
|
100 |
)
|
101 |
except (AttributeError, KeyError) as e:
|
|
|
74 |
except (AttributeError, KeyError):
|
75 |
logger.warning("Could not access custom_prompt_prefix component")
|
76 |
|
77 |
+
# Check if we have access to project_tabs_component
|
78 |
+
if hasattr(self.app, "project_tabs_component"):
|
79 |
+
tabs_component = self.app.project_tabs_component
|
80 |
+
else:
|
81 |
+
logger.warning("project_tabs_component not found in app, using None for tab switching")
|
82 |
+
tabs_component = None
|
83 |
+
|
84 |
# YouTube download event
|
85 |
download_event = self.components["youtube_download_btn"].click(
|
86 |
fn=self.app.importing.download_youtube_video,
|
|
|
100 |
custom_prompt_prefix
|
101 |
],
|
102 |
outputs=[
|
103 |
+
tabs_component,
|
104 |
+
self.components["import_status"]
|
|
|
105 |
]
|
106 |
)
|
107 |
except (AttributeError, KeyError) as e:
|
vms/ui/project/tabs/manage_tab.py
CHANGED
@@ -22,7 +22,7 @@ class ManageTab(BaseTab):
|
|
22 |
def __init__(self, app_state):
|
23 |
super().__init__(app_state)
|
24 |
self.id = "manage_tab"
|
25 |
-
self.title = "
|
26 |
|
27 |
def create(self, parent=None) -> gr.TabItem:
|
28 |
"""Create the Manage tab UI components"""
|
@@ -108,11 +108,9 @@ class ManageTab(BaseTab):
|
|
108 |
fn=self.handle_global_stop,
|
109 |
outputs=[
|
110 |
self.components["global_status"],
|
111 |
-
self.app.tabs["split_tab"].components["video_list"],
|
112 |
self.app.tabs["caption_tab"].components["training_dataset"],
|
113 |
self.app.tabs["train_tab"].components["status_box"],
|
114 |
self.app.tabs["train_tab"].components["log_box"],
|
115 |
-
self.app.tabs["split_tab"].components["detect_status"],
|
116 |
self.app.tabs["import_tab"].components["import_status"],
|
117 |
self.app.tabs["caption_tab"].components["preview_status"]
|
118 |
]
|
@@ -169,11 +167,9 @@ class ManageTab(BaseTab):
|
|
169 |
|
170 |
return {
|
171 |
self.components["global_status"]: gr.update(value=full_status, visible=True),
|
172 |
-
self.app.tabs["split_tab"].components["video_list"]: videos,
|
173 |
self.app.tabs["caption_tab"].components["training_dataset"]: clips,
|
174 |
self.app.tabs["train_tab"].components["status_box"]: "Training stopped and data cleared",
|
175 |
self.app.tabs["train_tab"].components["log_box"]: "",
|
176 |
-
self.app.tabs["split_tab"].components["detect_status"]: "Scene detection stopped",
|
177 |
self.app.tabs["import_tab"].components["import_status"]: "All data cleared",
|
178 |
self.app.tabs["caption_tab"].components["preview_status"]: "Captioning stopped"
|
179 |
}
|
|
|
22 |
def __init__(self, app_state):
|
23 |
super().__init__(app_state)
|
24 |
self.id = "manage_tab"
|
25 |
+
self.title = "5️⃣ Storage"
|
26 |
|
27 |
def create(self, parent=None) -> gr.TabItem:
|
28 |
"""Create the Manage tab UI components"""
|
|
|
108 |
fn=self.handle_global_stop,
|
109 |
outputs=[
|
110 |
self.components["global_status"],
|
|
|
111 |
self.app.tabs["caption_tab"].components["training_dataset"],
|
112 |
self.app.tabs["train_tab"].components["status_box"],
|
113 |
self.app.tabs["train_tab"].components["log_box"],
|
|
|
114 |
self.app.tabs["import_tab"].components["import_status"],
|
115 |
self.app.tabs["caption_tab"].components["preview_status"]
|
116 |
]
|
|
|
167 |
|
168 |
return {
|
169 |
self.components["global_status"]: gr.update(value=full_status, visible=True),
|
|
|
170 |
self.app.tabs["caption_tab"].components["training_dataset"]: clips,
|
171 |
self.app.tabs["train_tab"].components["status_box"]: "Training stopped and data cleared",
|
172 |
self.app.tabs["train_tab"].components["log_box"]: "",
|
|
|
173 |
self.app.tabs["import_tab"].components["import_status"]: "All data cleared",
|
174 |
self.app.tabs["caption_tab"].components["preview_status"]: "Captioning stopped"
|
175 |
}
|
vms/ui/project/tabs/preview_tab.py
CHANGED
@@ -23,7 +23,7 @@ class PreviewTab(BaseTab):
|
|
23 |
def __init__(self, app_state):
|
24 |
super().__init__(app_state)
|
25 |
self.id = "preview_tab"
|
26 |
-
self.title = "
|
27 |
|
28 |
def create(self, parent=None) -> gr.TabItem:
|
29 |
"""Create the Preview tab UI components"""
|
@@ -193,26 +193,31 @@ class PreviewTab(BaseTab):
|
|
193 |
"""Get model version choices based on model type"""
|
194 |
# Convert UI display name to internal name
|
195 |
internal_type = MODEL_TYPES.get(model_type)
|
196 |
-
if not internal_type:
|
197 |
-
|
198 |
-
|
199 |
-
# Get versions from preview service
|
200 |
-
versions = self.app.previewing.get_model_versions(internal_type)
|
201 |
-
if not versions:
|
202 |
return []
|
203 |
|
204 |
-
#
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
|
209 |
-
return choices
|
210 |
-
|
211 |
def get_default_model_version(self, model_type: str) -> str:
|
212 |
-
"""Get default model version for the model type"""
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
return ""
|
217 |
|
218 |
def get_default_model_type(self) -> str:
|
|
|
23 |
def __init__(self, app_state):
|
24 |
super().__init__(app_state)
|
25 |
self.id = "preview_tab"
|
26 |
+
self.title = "4️⃣ Preview"
|
27 |
|
28 |
def create(self, parent=None) -> gr.TabItem:
|
29 |
"""Create the Preview tab UI components"""
|
|
|
193 |
"""Get model version choices based on model type"""
|
194 |
# Convert UI display name to internal name
|
195 |
internal_type = MODEL_TYPES.get(model_type)
|
196 |
+
if not internal_type or internal_type not in MODEL_VERSIONS:
|
197 |
+
logger.warning(f"No model versions found for {model_type} (internal type: {internal_type})")
|
|
|
|
|
|
|
|
|
198 |
return []
|
199 |
|
200 |
+
# Return just the model IDs as a list of simple strings
|
201 |
+
version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys())
|
202 |
+
logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
|
203 |
+
return version_ids
|
204 |
|
|
|
|
|
205 |
def get_default_model_version(self, model_type: str) -> str:
|
206 |
+
"""Get default model version for the given model type"""
|
207 |
+
# Convert UI display name to internal name
|
208 |
+
internal_type = MODEL_TYPES.get(model_type)
|
209 |
+
logger.debug(f"get_default_model_version({model_type}) = {internal_type}")
|
210 |
+
|
211 |
+
if not internal_type or internal_type not in MODEL_VERSIONS:
|
212 |
+
logger.warning(f"No valid model versions found for {model_type}")
|
213 |
+
return ""
|
214 |
+
|
215 |
+
# Get the first version available for this model type
|
216 |
+
versions = list(MODEL_VERSIONS.get(internal_type, {}).keys())
|
217 |
+
if versions:
|
218 |
+
default_version = versions[0]
|
219 |
+
logger.debug(f"Default version for {model_type}: {default_version}")
|
220 |
+
return default_version
|
221 |
return ""
|
222 |
|
223 |
def get_default_model_type(self) -> str:
|
vms/ui/project/tabs/split_tab.py
DELETED
@@ -1,81 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Split tab for Video Model Studio UI
|
3 |
-
"""
|
4 |
-
|
5 |
-
import gradio as gr
|
6 |
-
import logging
|
7 |
-
from typing import Dict, Any, List, Optional
|
8 |
-
|
9 |
-
from vms.utils import BaseTab
|
10 |
-
|
11 |
-
logger = logging.getLogger(__name__)
|
12 |
-
|
13 |
-
class SplitTab(BaseTab):
|
14 |
-
"""Split tab for scene detection and video splitting"""
|
15 |
-
|
16 |
-
def __init__(self, app_state):
|
17 |
-
super().__init__(app_state)
|
18 |
-
self.id = "split_tab"
|
19 |
-
self.title = "2️⃣ Split"
|
20 |
-
|
21 |
-
def create(self, parent=None) -> gr.TabItem:
|
22 |
-
"""Create the Split tab UI components"""
|
23 |
-
with gr.TabItem(self.title, id=self.id) as tab:
|
24 |
-
with gr.Row():
|
25 |
-
self.components["split_title"] = gr.Markdown("## Splitting of 0 videos (0 bytes)")
|
26 |
-
|
27 |
-
with gr.Row():
|
28 |
-
with gr.Column():
|
29 |
-
self.components["detect_btn"] = gr.Button("Split videos into single-camera shots", variant="primary")
|
30 |
-
self.components["detect_status"] = gr.Textbox(label="Status", interactive=False)
|
31 |
-
|
32 |
-
with gr.Column():
|
33 |
-
self.components["video_list"] = gr.Dataframe(
|
34 |
-
headers=["name", "status"],
|
35 |
-
label="Videos to split (note: Nvidia A100 cannot split videos encoded in AV1)",
|
36 |
-
interactive=False,
|
37 |
-
wrap=True
|
38 |
-
)
|
39 |
-
|
40 |
-
return tab
|
41 |
-
|
42 |
-
def connect_events(self) -> None:
|
43 |
-
"""Connect event handlers to UI components"""
|
44 |
-
# Scene detection button event
|
45 |
-
self.components["detect_btn"].click(
|
46 |
-
fn=self.start_scene_detection,
|
47 |
-
inputs=[self.app.tabs["import_tab"].components["enable_automatic_video_split"]],
|
48 |
-
outputs=[self.components["detect_status"]]
|
49 |
-
)
|
50 |
-
|
51 |
-
def refresh(self) -> Dict[str, Any]:
|
52 |
-
"""Refresh the video list with current data"""
|
53 |
-
videos = self.list_unprocessed_videos()
|
54 |
-
return {
|
55 |
-
"video_list": videos
|
56 |
-
}
|
57 |
-
|
58 |
-
def list_unprocessed_videos(self) -> gr.Dataframe:
|
59 |
-
"""Update list of unprocessed videos"""
|
60 |
-
videos = self.app.splitting.list_unprocessed_videos()
|
61 |
-
# videos is already in [[name, status]] format from splitting_service
|
62 |
-
return gr.Dataframe(
|
63 |
-
headers=["name", "status"],
|
64 |
-
value=videos,
|
65 |
-
interactive=False
|
66 |
-
)
|
67 |
-
|
68 |
-
async def start_scene_detection(self, enable_splitting: bool) -> str:
|
69 |
-
"""Start background scene detection process
|
70 |
-
|
71 |
-
Args:
|
72 |
-
enable_splitting: Whether to split videos into scenes
|
73 |
-
"""
|
74 |
-
if self.app.splitting.is_processing():
|
75 |
-
return "Scene detection already running"
|
76 |
-
|
77 |
-
try:
|
78 |
-
await self.app.splitting.start_processing(enable_splitting)
|
79 |
-
return "Scene detection completed"
|
80 |
-
except Exception as e:
|
81 |
-
return f"Error during scene detection: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vms/ui/project/tabs/train_tab.py
CHANGED
@@ -35,7 +35,7 @@ class TrainTab(BaseTab):
|
|
35 |
def __init__(self, app_state):
|
36 |
super().__init__(app_state)
|
37 |
self.id = "train_tab"
|
38 |
-
self.title = "
|
39 |
|
40 |
def create(self, parent=None) -> gr.TabItem:
|
41 |
"""Create the Train tab UI components"""
|
@@ -58,23 +58,46 @@ class TrainTab(BaseTab):
|
|
58 |
with gr.Column():
|
59 |
# Get the default model type from the first preset
|
60 |
default_model_type = list(MODEL_TYPES.keys())[0]
|
61 |
-
|
62 |
self.components["model_type"] = gr.Dropdown(
|
63 |
choices=list(MODEL_TYPES.keys()),
|
64 |
label="Model Type",
|
65 |
value=default_model_type,
|
66 |
interactive=True
|
67 |
)
|
68 |
-
|
69 |
# Get model versions for the default model type
|
70 |
default_model_versions = self.get_model_version_choices(default_model_type)
|
71 |
default_model_version = self.get_default_model_version(default_model_type)
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
self.components["model_version"] = gr.Dropdown(
|
74 |
choices=default_model_versions,
|
75 |
label="Model Version",
|
76 |
value=default_model_version,
|
77 |
-
interactive=True
|
|
|
78 |
)
|
79 |
|
80 |
self.components["training_type"] = gr.Dropdown(
|
@@ -428,15 +451,38 @@ class TrainTab(BaseTab):
|
|
428 |
|
429 |
def update_model_versions(self, model_type: str) -> Dict:
|
430 |
"""Update model version choices based on selected model type"""
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
def handle_training_start(
|
441 |
self, preset, model_type, model_version, training_type,
|
442 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
@@ -509,26 +555,30 @@ class TrainTab(BaseTab):
|
|
509 |
# Convert UI display name to internal name
|
510 |
internal_type = MODEL_TYPES.get(model_type)
|
511 |
if not internal_type or internal_type not in MODEL_VERSIONS:
|
|
|
512 |
return []
|
513 |
|
514 |
-
# Return just the model IDs
|
515 |
-
|
|
|
|
|
516 |
|
517 |
-
|
518 |
def get_default_model_version(self, model_type: str) -> str:
|
519 |
"""Get default model version for the given model type"""
|
520 |
# Convert UI display name to internal name
|
521 |
internal_type = MODEL_TYPES.get(model_type)
|
522 |
-
|
|
|
523 |
if not internal_type or internal_type not in MODEL_VERSIONS:
|
|
|
524 |
return ""
|
525 |
|
526 |
# Get the first version available for this model type
|
527 |
-
versions = MODEL_VERSIONS.get(internal_type, {})
|
528 |
if versions:
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
return ""
|
533 |
|
534 |
def update_model_info(self, model_type: str, training_type: str) -> Dict:
|
@@ -698,7 +748,19 @@ class TrainTab(BaseTab):
|
|
698 |
# Get the appropriate model version for the selected model type
|
699 |
model_versions = self.get_model_version_choices(model_display_name)
|
700 |
default_model_version = self.get_default_model_version(model_display_name)
|
701 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
702 |
# Return values in the same order as the output components
|
703 |
return (
|
704 |
model_display_name,
|
@@ -714,7 +776,7 @@ class TrainTab(BaseTab):
|
|
714 |
num_gpus_val,
|
715 |
precomputation_items_val,
|
716 |
lr_warmup_steps_val,
|
717 |
-
|
718 |
)
|
719 |
|
720 |
|
|
|
35 |
def __init__(self, app_state):
|
36 |
super().__init__(app_state)
|
37 |
self.id = "train_tab"
|
38 |
+
self.title = "3️⃣ Train"
|
39 |
|
40 |
def create(self, parent=None) -> gr.TabItem:
|
41 |
"""Create the Train tab UI components"""
|
|
|
58 |
with gr.Column():
|
59 |
# Get the default model type from the first preset
|
60 |
default_model_type = list(MODEL_TYPES.keys())[0]
|
61 |
+
|
62 |
self.components["model_type"] = gr.Dropdown(
|
63 |
choices=list(MODEL_TYPES.keys()),
|
64 |
label="Model Type",
|
65 |
value=default_model_type,
|
66 |
interactive=True
|
67 |
)
|
68 |
+
|
69 |
# Get model versions for the default model type
|
70 |
default_model_versions = self.get_model_version_choices(default_model_type)
|
71 |
default_model_version = self.get_default_model_version(default_model_type)
|
72 |
+
|
73 |
+
# Ensure default_model_versions is not empty
|
74 |
+
if not default_model_versions:
|
75 |
+
# If no versions found for default model, use a fallback
|
76 |
+
internal_type = MODEL_TYPES.get(default_model_type)
|
77 |
+
if internal_type in MODEL_VERSIONS:
|
78 |
+
default_model_versions = list(MODEL_VERSIONS[internal_type].keys())
|
79 |
+
else:
|
80 |
+
# Last resort - collect all available versions from all models
|
81 |
+
default_model_versions = []
|
82 |
+
for model_versions in MODEL_VERSIONS.values():
|
83 |
+
default_model_versions.extend(list(model_versions.keys()))
|
84 |
+
|
85 |
+
# If still empty, provide a placeholder
|
86 |
+
if not default_model_versions:
|
87 |
+
default_model_versions = ["-- No versions available --"]
|
88 |
+
|
89 |
+
# Set default version to first in list if available
|
90 |
+
if default_model_versions:
|
91 |
+
default_model_version = default_model_versions[0]
|
92 |
+
else:
|
93 |
+
default_model_version = ""
|
94 |
+
|
95 |
self.components["model_version"] = gr.Dropdown(
|
96 |
choices=default_model_versions,
|
97 |
label="Model Version",
|
98 |
value=default_model_version,
|
99 |
+
interactive=True,
|
100 |
+
allow_custom_value=True # Add this to avoid errors with custom values
|
101 |
)
|
102 |
|
103 |
self.components["training_type"] = gr.Dropdown(
|
|
|
451 |
|
452 |
def update_model_versions(self, model_type: str) -> Dict:
|
453 |
"""Update model version choices based on selected model type"""
|
454 |
+
try:
|
455 |
+
# Get version choices for this model type
|
456 |
+
model_versions = self.get_model_version_choices(model_type)
|
457 |
+
|
458 |
+
# Get default version
|
459 |
+
default_version = self.get_default_model_version(model_type)
|
460 |
+
logger.info(f"update_model_versions({model_type}): default_version = {default_version}, available versions: {model_versions}")
|
461 |
+
|
462 |
+
# Update UI state with proper model_type first
|
463 |
+
self.app.update_ui_state(model_type=model_type)
|
464 |
+
|
465 |
+
# Create a new dropdown with the updated choices
|
466 |
+
if not model_versions:
|
467 |
+
logger.warning(f"No model versions available for {model_type}, using empty list")
|
468 |
+
# Return empty dropdown to avoid errors
|
469 |
+
return gr.Dropdown(choices=[], value=None)
|
470 |
+
|
471 |
+
# Ensure default_version is in model_versions
|
472 |
+
if default_version not in model_versions and model_versions:
|
473 |
+
default_version = model_versions[0]
|
474 |
+
logger.info(f"Default version not in choices, using first available: {default_version}")
|
475 |
+
|
476 |
+
# Return the updated dropdown
|
477 |
+
logger.info(f"Returning dropdown with {len(model_versions)} choices")
|
478 |
+
return gr.Dropdown(choices=model_versions, value=default_version)
|
479 |
+
except Exception as e:
|
480 |
+
# Log any exceptions for debugging
|
481 |
+
logger.error(f"Error in update_model_versions: {str(e)}")
|
482 |
+
# Return empty dropdown to avoid errors
|
483 |
+
return gr.Dropdown(choices=[], value=None)
|
484 |
+
|
485 |
+
|
486 |
def handle_training_start(
|
487 |
self, preset, model_type, model_version, training_type,
|
488 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
|
|
555 |
# Convert UI display name to internal name
|
556 |
internal_type = MODEL_TYPES.get(model_type)
|
557 |
if not internal_type or internal_type not in MODEL_VERSIONS:
|
558 |
+
logger.warning(f"No model versions found for {model_type} (internal type: {internal_type})")
|
559 |
return []
|
560 |
|
561 |
+
# Return just the model IDs as a list of simple strings
|
562 |
+
version_ids = list(MODEL_VERSIONS.get(internal_type, {}).keys())
|
563 |
+
logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
|
564 |
+
return version_ids
|
565 |
|
|
|
566 |
def get_default_model_version(self, model_type: str) -> str:
|
567 |
"""Get default model version for the given model type"""
|
568 |
# Convert UI display name to internal name
|
569 |
internal_type = MODEL_TYPES.get(model_type)
|
570 |
+
logger.debug(f"get_default_model_version({model_type}) = {internal_type}")
|
571 |
+
|
572 |
if not internal_type or internal_type not in MODEL_VERSIONS:
|
573 |
+
logger.warning(f"No valid model versions found for {model_type}")
|
574 |
return ""
|
575 |
|
576 |
# Get the first version available for this model type
|
577 |
+
versions = list(MODEL_VERSIONS.get(internal_type, {}).keys())
|
578 |
if versions:
|
579 |
+
default_version = versions[0]
|
580 |
+
logger.debug(f"Default version for {model_type}: {default_version}")
|
581 |
+
return default_version
|
582 |
return ""
|
583 |
|
584 |
def update_model_info(self, model_type: str, training_type: str) -> Dict:
|
|
|
748 |
# Get the appropriate model version for the selected model type
|
749 |
model_versions = self.get_model_version_choices(model_display_name)
|
750 |
default_model_version = self.get_default_model_version(model_display_name)
|
751 |
+
|
752 |
+
# Create the model version dropdown update
|
753 |
+
model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)
|
754 |
+
|
755 |
+
# Ensure we have valid choices and values
|
756 |
+
if not model_versions:
|
757 |
+
logger.warning(f"No versions found for {model_display_name}, using empty list")
|
758 |
+
model_versions = []
|
759 |
+
default_model_version = None
|
760 |
+
elif default_model_version not in model_versions and model_versions:
|
761 |
+
default_model_version = model_versions[0]
|
762 |
+
logger.info(f"Reset default version to first available: {default_model_version}")
|
763 |
+
|
764 |
# Return values in the same order as the output components
|
765 |
return (
|
766 |
model_display_name,
|
|
|
776 |
num_gpus_val,
|
777 |
precomputation_items_val,
|
778 |
lr_warmup_steps_val,
|
779 |
+
model_version_update,
|
780 |
)
|
781 |
|
782 |
|