Commit b613c3c · 1 Parent(s): ecd5028
work on basic monitor (no gpu for now)
- README.md +29 -2
- app.py +4 -1
- requirements.txt +5 -1
- requirements_without_flash_attention.txt +6 -2
- run.sh +9 -1
- setup_no_captions.sh +10 -2
- vms/services/__init__.py +2 -0
- vms/services/importer/__init__.py +11 -0
- vms/services/{importer.py → importer/file_upload.py} +43 -57
- vms/services/importer/hub_dataset.py +521 -0
- vms/services/importer/import_service.py +102 -0
- vms/services/importer/youtube.py +73 -0
- vms/services/monitoring.py +361 -0
- vms/tabs/__init__.py +2 -0
- vms/tabs/import_tab/__init__.py +10 -0
- vms/tabs/import_tab/hub_tab.py +273 -0
- vms/tabs/{import_tab.py → import_tab/import_tab.py} +40 -80
- vms/tabs/import_tab/upload_tab.py +74 -0
- vms/tabs/import_tab/youtube_tab.py +67 -0
- vms/tabs/manage_tab.py +1 -1
- vms/tabs/monitor_tab.py +407 -0
- vms/ui/video_trainer_ui.py +8 -3
README.md
CHANGED
@@ -120,6 +120,33 @@ As this is not automatic, then click on "Restart" in the space dev mode UI widget
 
 I haven't tested it, but you can try the provided Dockerfile
 
+### Prerequisites
+
+About Python:
+
+I haven't tested Python 3.11 or 3.12, but I noticed some incompatibilities with Python 3.13 (dependencies failing to install).
+
+So I recommend you install [pyenv](https://github.com/pyenv/pyenv) to switch between versions of Python.
+
+If you are on macOS, you might already have some versions of Python installed; you can see them by typing:
+
+```bash
+% python3.10 --version
+Python 3.10.16
+% python3.11 --version
+Python 3.11.11
+% python3.12 --version
+Python 3.12.9
+% python3.13 --version
+Python 3.13.2
+```
+
+Once pyenv is installed you can type:
+
+```bash
+pyenv install 3.10.16
+```
+
 ### Full installation in local
 
 the full installation requires:
@@ -127,7 +154,7 @@
 - CUDA 12
 - Python 3.10
 
-This is because of flash attention, which is defined in the `requirements.txt` using an URL to download a prebuilt wheel (python bindings for a native library)
+This is because of flash attention, which is defined in the `requirements.txt` using a URL to download a prebuilt wheel expecting this exact configuration (Python bindings for a native library)
 
 ```bash
 ./setup.sh
@@ -153,7 +180,7 @@ Here is how to do solution 3:
 Note: please make sure you properly define the environment variables for `STORAGE_PATH` (eg. `/data/`) and `HF_HOME` (eg. `/data/huggingface/`)
 
 ```bash
-
+python3.10 app.py
 ```
 
 ### Running locally
app.py
CHANGED
@@ -14,6 +14,7 @@ from vms.config import (
     OUTPUT_PATH, ASK_USER_TO_DUPLICATE_SPACE,
     HF_API_TOKEN
 )
+
 from vms.ui.video_trainer_ui import VideoTrainerUI
 
 # Configure logging
@@ -37,7 +38,9 @@ To avoid overpaying for your space, you can configure the auto-sleep settings to
 
 # Create the main application UI
 ui = VideoTrainerUI()
-
+app = ui.create_ui()
+
+return app
 
 def main():
     """Main entry point for the application"""
requirements.txt
CHANGED
@@ -2,6 +2,7 @@ numpy>=1.26.4
 
 # to quote a-r-r-o-w/finetrainers:
 # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
+# on some system (Python 3.13+) those do not work:
 torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
@@ -41,4 +42,7 @@ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
 
 # for our frontend
 gradio==5.20.1
-gradio_toggle
+gradio_toggle
+
+# used for the monitor
+matplotlib
requirements_without_flash_attention.txt
CHANGED
@@ -2,11 +2,12 @@ numpy>=1.26.4
 
 # to quote a-r-r-o-w/finetrainers:
 # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
+
+# on some system (Python 3.13+) those do not work:
 torch==2.5.1
 torchvision==0.20.1
 torchao==0.6.1
 
-
 huggingface_hub
 hf_transfer>=0.1.8
 diffusers @ git+https://github.com/huggingface/diffusers.git@main
@@ -40,4 +41,7 @@ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
 
 # for our frontend
 gradio==5.20.1
-gradio_toggle
+gradio_toggle
+
+# used for the monitor
+matplotlib
run.sh
CHANGED
@@ -2,4 +2,12 @@
 
 source .venv/bin/activate
 
-
+echo "if run.sh fails due to python being not found, edit run.sh to replace with another version of python"
+
+# if you are on a mac, you can try to replace "python3.10" with:
+# python3.10
+# python3.11 (not tested)
+# python3.12 (not tested)
+# python3.13 (tested, fails to install)
+
+USE_MOCK_CAPTIONING_MODEL=True python3.10 app.py
setup_no_captions.sh
CHANGED
@@ -1,10 +1,18 @@
 #!/usr/bin/env bash
 
-python
+echo "if install fails due to python being not found, edit setup_no_captions.sh to replace with another version of python"
+
+# if you are on a mac, you can try to replace "python3.10" with:
+# python3.10
+# python3.11 (not tested)
+# python3.12 (not tested)
+# python3.13 (tested, fails to install)
+
+python3.10 -m venv .venv
 
 source .venv/bin/activate
 
-
+python3.10 -m pip install -r requirements_without_flash_attention.txt
 
 # if you require flash attention, please install it manually for your operating system
 
vms/services/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 from .captioner import CaptioningProgress, CaptioningService
 from .importer import ImportService
+from .monitoring import MonitoringService
 from .splitter import SplittingService
 from .trainer import TrainingService
 
@@ -7,6 +8,7 @@ __all__ = [
     'CaptioningProgress',
     'CaptioningService',
     'ImportService',
+    'MonitoringService',
     'SplittingService',
     'TrainingService',
 ]
vms/services/importer/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
Import module for Video Model Studio.
Handles file uploads, YouTube downloads, and Hugging Face Hub dataset integration.
"""

from .import_service import ImportService
from .file_upload import FileUploadHandler
from .youtube import YouTubeDownloader
from .hub_dataset import HubDatasetBrowser

__all__ = ['ImportService', 'FileUploadHandler', 'YouTubeDownloader', 'HubDatasetBrowser']
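For a quick sanity check, the public surface declared above can be imported directly; a minimal sketch, assuming the repository root is on `PYTHONPATH`:

```python
# Minimal sketch: import the names exported by vms/services/importer/__init__.py.
from vms.services.importer import (
    ImportService,
    FileUploadHandler,
    YouTubeDownloader,
    HubDatasetBrowser,
)

# The facade plus the three specialized handlers are the whole public surface.
print([cls.__name__ for cls in (ImportService, FileUploadHandler, YouTubeDownloader, HubDatasetBrowser)])
```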
vms/services/{importer.py → importer/file_upload.py}
RENAMED
@@ -1,3 +1,8 @@
+"""
+File upload handler for Video Model Studio.
+Processes uploaded files including videos, images, ZIPs, and WebDataset archives.
+"""
+
 import os
 import shutil
 import zipfile
@@ -5,16 +10,18 @@ import tarfile
 import tempfile
 import gradio as gr
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple
-from pytubefix import YouTube
+from typing import List, Dict, Optional, Tuple, Any, Union
 import logging
+import traceback
 
-from
-from
+from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
+from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
 
 logger = logging.getLogger(__name__)
 
-class ImportService:
+class FileUploadHandler:
+    """Handles processing of uploaded files"""
+
     def process_uploaded_files(self, file_paths: List[str]) -> str:
         """Process uploaded file (ZIP, TAR, MP4, or image)
 
@@ -24,11 +31,15 @@ class ImportService:
         Returns:
             Status message string
        """
+        if not file_paths or len(file_paths) == 0:
+            logger.warning("No files provided to process_uploaded_files")
+            return "No files provided"
+
        for file_path in file_paths:
            file_path = Path(file_path)
            try:
                original_name = file_path.name
-
+                logger.info(f"Processing uploaded file: {original_name}")
 
                # Determine file type from name
                file_ext = file_path.suffix.lower()
@@ -42,9 +53,11 @@
                elif is_image_file(file_path):
                    return self.process_image_file(file_path, original_name)
                else:
+                    logger.error(f"Unsupported file type: {file_ext}")
                    raise gr.Error(f"Unsupported file type: {file_ext}")
 
            except Exception as e:
+                logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
                raise gr.Error(f"Error processing file: {str(e)}")
 
    def process_image_file(self, file_path: Path, original_name: str) -> str:
@@ -68,10 +81,13 @@
                target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                counter += 1
 
+            logger.info(f"Processing image file: {original_name} -> {target_path}")
+
            # Convert to normalized format and remove black bars
            success = normalize_image(file_path, target_path)
 
            if not success:
+                logger.error(f"Failed to process image: {original_name}")
                raise gr.Error(f"Failed to process image: {original_name}")
 
            # Handle caption
@@ -86,6 +102,7 @@
            return f"Successfully stored image: {target_path.name}"
 
        except Exception as e:
+            logger.error(f"Error processing image file: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing image file: {str(e)}")
 
    def process_zip_file(self, file_path: Path) -> str:
@@ -102,6 +119,8 @@
            image_count = 0
            tar_count = 0
 
+            logger.info(f"Processing ZIP file: {file_path}")
+
            # Create temporary directory
            with tempfile.TemporaryDirectory() as temp_dir:
                # Extract ZIP
@@ -121,6 +140,7 @@
                    try:
                        # Check if it's a WebDataset tar file
                        if file.lower().endswith('.tar'):
+                            logger.info(f"Processing WebDataset archive from ZIP: {file}")
                            # Process WebDataset shard
                            vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
@@ -136,6 +156,7 @@
                                target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}"
                                counter += 1
                            shutil.copy2(file_path, target_path)
+                            logger.info(f"Extracted video from ZIP: {file} -> {target_path.name}")
                            video_count += 1
 
                        elif is_image_file(file_path):
@@ -146,6 +167,7 @@
                                target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                                counter += 1
                            if normalize_image(file_path, target_path):
+                                logger.info(f"Extracted image from ZIP: {file} -> {target_path.name}")
                                image_count += 1
 
                        # Copy associated caption file if it exists
@@ -153,13 +175,15 @@
                        if txt_path.exists() and not file.lower().endswith('.tar'):
                            if is_video_file(file_path):
                                shutil.copy2(txt_path, target_path.with_suffix('.txt'))
+                                logger.info(f"Copied caption file for {file}")
                            elif is_image_file(file_path):
                                caption = txt_path.read_text()
                                caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
                                target_path.with_suffix('.txt').write_text(caption)
+                                logger.info(f"Processed caption for {file}")
 
                    except Exception as e:
-                        logger.error(f"Error processing {file_path.name}: {str(e)}")
+                        logger.error(f"Error processing {file_path.name} from ZIP: {str(e)}", exc_info=True)
                        continue
 
            # Generate status message
@@ -172,13 +196,16 @@
                parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
 
            if not parts:
+                logger.warning("No supported media files found in ZIP")
                return "No supported media files found in ZIP"
 
            status = f"Successfully stored {', '.join(parts)}"
+            logger.info(status)
            gr.Info(status)
            return status
 
        except Exception as e:
+            logger.error(f"Error processing ZIP: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing ZIP: {str(e)}")
 
    def process_tar_file(self, file_path: Path) -> str:
@@ -191,6 +218,7 @@
            Status message string
        """
        try:
+            logger.info(f"Processing WebDataset TAR file: {file_path}")
            video_count, image_count = webdataset_handler.process_webdataset_shard(
                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
            )
@@ -203,13 +231,16 @@
                parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
 
            if not parts:
+                logger.warning("No supported media files found in WebDataset")
                return "No supported media files found in WebDataset"
 
            status = f"Successfully extracted {' and '.join(parts)} from WebDataset"
+            logger.info(status)
            gr.Info(status)
            return status
 
        except Exception as e:
+            logger.error(f"Error processing WebDataset tar file: {str(e)}", exc_info=True)
            raise gr.Error(f"Error processing WebDataset tar file: {str(e)}")
 
    def process_mp4_file(self, file_path: Path, original_name: str) -> str:
@@ -233,60 +264,15 @@
            target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4"
            counter += 1
 
+            logger.info(f"Processing video file: {original_name} -> {target_path}")
+
            # Copy the file to the target location
            shutil.copy2(file_path, target_path)
 
+            logger.info(f"Successfully stored video: {target_path.name}")
            gr.Info(f"Successfully stored video: {target_path.name}")
            return f"Successfully stored video: {target_path.name}"
 
        except Exception as e:
-
-
-    def download_youtube_video(self, url: str, progress=None) -> Dict:
-        """Download a video from YouTube
-
-        Args:
-            url: YouTube video URL
-            progress: Optional Gradio progress indicator
-
-        Returns:
-            Dict with status message and error (if any)
-        """
-        try:
-            # Extract video ID and create YouTube object
-            yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining:
-                        progress((1 - bytes_remaining / stream.filesize), desc="Downloading...")
-                        if progress else None)
-
-            video_id = yt.video_id
-            output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"
-
-            # Download highest quality progressive MP4
-            if progress:
-                print("Getting video streams...")
-                progress(0, desc="Getting video streams...")
-            video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
-
-            if not video:
-                print("Could not find a compatible video format")
-                gr.Error("Could not find a compatible video format")
-                return "Could not find a compatible video format"
-
-            # Download the video
-            if progress:
-                print("Starting YouTube video download...")
-                progress(0, desc="Starting download...")
-
-            video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")
-
-            # Update UI
-            if progress:
-                print("YouTube video download complete!")
-                gr.Info("YouTube video download complete!")
-                progress(1, desc="Download complete!")
-            return f"Successfully downloaded video: {yt.title}"
-
-        except Exception as e:
-            print(e)
-            gr.Error(f"Error downloading video: {str(e)}")
-            return f"Error downloading video: {str(e)}"
+            logger.error(f"Error processing video file: {str(e)}", exc_info=True)
+            raise gr.Error(f"Error processing video file: {str(e)}")
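As a rough illustration (not part of this commit), the renamed handler can be driven without the Gradio UI; the file path below is a placeholder for whatever the upload widget would normally pass in:

```python
from vms.services.importer import FileUploadHandler

handler = FileUploadHandler()

# "/tmp/example_clip.mp4" is a hypothetical path; Gradio normally supplies the list.
status = handler.process_uploaded_files(["/tmp/example_clip.mp4"])
print(status)  # e.g. "Successfully stored video: example_clip.mp4"
```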
vms/services/importer/hub_dataset.py
ADDED
@@ -0,0 +1,521 @@
"""
Hugging Face Hub dataset browser for Video Model Studio.
Handles searching, viewing, and downloading datasets from the Hub.
"""

import os
import shutil
import tempfile
import asyncio
import logging
import gradio as gr
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union

from huggingface_hub import (
    HfApi,
    hf_hub_download,
    snapshot_download,
    list_datasets
)

from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler

logger = logging.getLogger(__name__)

class HubDatasetBrowser:
    """Handles interactions with Hugging Face Hub datasets"""

    def __init__(self, hf_api: HfApi):
        """Initialize with HfApi instance

        Args:
            hf_api: Hugging Face Hub API instance
        """
        self.hf_api = hf_api

    def search_datasets(self, query: str) -> List[List[str]]:
        """Search for datasets on the Hugging Face Hub

        Args:
            query: Search query string

        Returns:
            List of datasets matching the query [id, title, downloads]
        """
        try:
            # Start with some filters to find video-related datasets
            search_terms = query.strip() if query and query.strip() else "video"
            logger.info(f"Searching datasets with query: '{search_terms}'")

            # Fetch datasets that match the search
            datasets = list(self.hf_api.list_datasets(
                search=search_terms,
                limit=50
            ))

            # Format results for display
            results = []
            for ds in datasets:
                # Extract relevant information
                dataset_id = ds.id

                # Safely get the title with fallbacks
                card_data = getattr(ds, "card_data", None)
                title = ""

                if card_data is not None and isinstance(card_data, dict):
                    title = card_data.get("name", "")

                if not title:
                    # Use the last part of the repo ID as a fallback
                    title = dataset_id.split("/")[-1]

                # Safely get downloads
                downloads = getattr(ds, "downloads", 0)
                if downloads is None:
                    downloads = 0

                results.append([dataset_id, title, downloads])

            # Sort by downloads (most downloaded first)
            results.sort(key=lambda x: x[2] if x[2] is not None else 0, reverse=True)

            logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
            return results

        except Exception as e:
            logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
            return [[f"Error: {str(e)}", "", ""]]

    def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
        """Get detailed information about a dataset

        Args:
            dataset_id: The dataset ID to get information for

        Returns:
            Tuple of (markdown_info, file_counts, file_groups)
            - markdown_info: Markdown formatted string with dataset information
            - file_counts: Dictionary with counts of each file type
            - file_groups: Dictionary with lists of filenames grouped by type
        """
        try:
            if not dataset_id:
                logger.warning("No dataset ID provided to get_dataset_info")
                return "No dataset selected", {}, {}

            logger.info(f"Getting info for dataset: {dataset_id}")

            # Get detailed information about the dataset
            dataset_info = self.hf_api.dataset_info(dataset_id)

            # Format the information for display
            info_text = f"## {dataset_info.id}\n\n"

            # Add description if available (with safer access)
            card_data = getattr(dataset_info, "card_data", None)
            description = ""

            if card_data is not None and isinstance(card_data, dict):
                description = card_data.get("description", "")

            if description:
                info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"

            # Add basic stats (with safer access)
            downloads = getattr(dataset_info, 'downloads', None)
            info_text += f"**Downloads:** {downloads if downloads is not None else 'N/A'}\n"

            last_modified = getattr(dataset_info, 'last_modified', None)
            info_text += f"**Last modified:** {last_modified if last_modified is not None else 'N/A'}\n"

            # Show tags if available (with safer access)
            tags = getattr(dataset_info, "tags", None) or []
            if tags:
                info_text += f"**Tags:** {', '.join(tags[:10])}\n\n"

            # Group files by type
            file_groups = {
                "video": [],
                "webdataset": []
            }

            siblings = getattr(dataset_info, "siblings", None) or []

            # Extract files by type
            for s in siblings:
                if not hasattr(s, 'rfilename'):
                    continue

                filename = s.rfilename
                if filename.lower().endswith((".mp4", ".webm")):
                    file_groups["video"].append(filename)
                elif filename.lower().endswith(".tar"):
                    file_groups["webdataset"].append(filename)

            # Create file counts dictionary
            file_counts = {
                "video": len(file_groups["video"]),
                "webdataset": len(file_groups["webdataset"])
            }

            logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
            return info_text, file_counts, file_groups

        except Exception as e:
            logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
            return f"Error loading dataset information: {str(e)}", {}, {}

    async def download_file_group(self, dataset_id: str, file_type: str, enable_splitting: bool = True) -> str:
        """Download all files of a specific type from the dataset

        Args:
            dataset_id: The dataset ID
            file_type: Either "video" or "webdataset"
            enable_splitting: Whether to enable automatic video splitting

        Returns:
            Status message
        """
        try:
            # Get dataset info to retrieve file list
            _, _, file_groups = self.get_dataset_info(dataset_id)

            # Get the list of files for the specified type
            files = file_groups.get(file_type, [])

            if not files:
                return f"No {file_type} files found in the dataset"

            logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")

            # Track counts for status message
            video_count = 0
            image_count = 0

            # Create a temporary directory for downloads
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)

                # Process all files of the requested type
                for filename in files:
                    try:
                        # Download the file
                        file_path = hf_hub_download(
                            repo_id=dataset_id,
                            filename=filename,
                            repo_type="dataset",
                            local_dir=temp_path
                        )

                        file_path = Path(file_path)
                        logger.info(f"Downloaded file to {file_path}")

                        # Process based on file type
                        if file_type == "video":
                            # Choose target directory based on auto-splitting setting
                            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                            target_path = target_dir / file_path.name

                            # Make sure filename is unique
                            counter = 1
                            while target_path.exists():
                                stem = Path(file_path.name).stem
                                if "___" in stem:
                                    base_stem = stem.split("___")[0]
                                else:
                                    base_stem = stem
                                target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
                                counter += 1

                            # Copy the video file
                            shutil.copy2(file_path, target_path)
                            logger.info(f"Processed video: {file_path.name} -> {target_path.name}")

                            # Try to download caption if it exists
                            try:
                                txt_filename = f"{Path(filename).stem}.txt"
                                for possible_path in [
                                    Path(filename).with_suffix('.txt').as_posix(),
                                    (Path(filename).parent / txt_filename).as_posix(),
                                ]:
                                    try:
                                        txt_path = hf_hub_download(
                                            repo_id=dataset_id,
                                            filename=possible_path,
                                            repo_type="dataset",
                                            local_dir=temp_path
                                        )
                                        shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                        logger.info(f"Copied caption for {file_path.name}")
                                        break
                                    except Exception:
                                        # Caption file doesn't exist at this path, try next
                                        pass
                            except Exception as e:
                                logger.warning(f"Error trying to download caption: {e}")

                            video_count += 1

                        elif file_type == "webdataset":
                            # Process the WebDataset archive
                            try:
                                logger.info(f"Processing WebDataset file: {file_path}")
                                vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                    file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                )
                                video_count += vid_count
                                image_count += img_count
                            except Exception as e:
                                logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)

                    except Exception as e:
                        logger.warning(f"Error processing file {filename}: {e}")

            # Generate status message
            if file_type == "video":
                return f"Successfully imported {video_count} videos from dataset {dataset_id}"
            elif file_type == "webdataset":
                parts = []
                if video_count > 0:
                    parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                if image_count > 0:
                    parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")

                if parts:
                    return f"Successfully imported {' and '.join(parts)} from WebDataset archives"
                else:
                    return f"No media was found in the WebDataset archives"

            return f"Unknown file type: {file_type}"

        except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return error_msg

    async def download_dataset(self, dataset_id: str, enable_splitting: bool = True) -> Tuple[str, str]:
        """Download a dataset and process its video/image content

        Args:
            dataset_id: The dataset ID to download
            enable_splitting: Whether to enable automatic video splitting

        Returns:
            Tuple of (loading_msg, status_msg)
        """
        if not dataset_id:
            logger.warning("No dataset ID provided for download")
            return "No dataset selected", "Please select a dataset first"

        try:
            logger.info(f"Starting download of dataset: {dataset_id}")
            loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."
            status_msg = f"Downloading dataset: {dataset_id}..."

            # Get dataset info to check for available files
            dataset_info = self.hf_api.dataset_info(dataset_id)

            # Check if there are video files or WebDataset files
            video_files = []
            tar_files = []

            siblings = getattr(dataset_info, "siblings", None) or []
            if siblings:
                video_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith((".mp4", ".webm"))]
                tar_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith(".tar")]

            # Create a temporary directory for downloads
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)

                # If we have video files, download them individually
                if video_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
                    logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")

                    for i, video_file in enumerate(video_files):
                        # Download the video file
                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
                                filename=video_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )

                            # Look for associated caption file
                            txt_filename = f"{Path(video_file).stem}.txt"
                            txt_path = None
                            for possible_path in [
                                Path(video_file).with_suffix('.txt').as_posix(),
                                (Path(video_file).parent / txt_filename).as_posix(),
                            ]:
                                try:
                                    txt_path = hf_hub_download(
                                        repo_id=dataset_id,
                                        filename=possible_path,
                                        repo_type="dataset",
                                        local_dir=temp_path
                                    )
                                    logger.info(f"Found caption file for {video_file}: {possible_path}")
                                    break
                                except Exception as e:
                                    # Caption file doesn't exist at this path, try next
                                    logger.debug(f"No caption at {possible_path}: {str(e)}")
                                    pass

                            status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
                            logger.info(status_msg)
                        except Exception as e:
                            logger.warning(f"Error downloading {video_file}: {e}")

                # If we have tar files, download them
                if tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
                    logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")

                    for i, tar_file in enumerate(tar_files):
                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
                                filename=tar_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )
                            status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
                            logger.info(status_msg)
                        except Exception as e:
                            logger.warning(f"Error downloading {tar_file}: {e}")

                # If no specific files were found, try downloading the entire repo
                if not video_files and not tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
                    logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")

                    try:
                        snapshot_download(
                            repo_id=dataset_id,
                            repo_type="dataset",
                            local_dir=temp_path
                        )
                        status_msg = f"Downloaded entire repository for {dataset_id}"
                        logger.info(status_msg)
                    except Exception as e:
                        logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
                        return loading_msg, f"Error downloading dataset: {str(e)}"

                # Process the downloaded files
                loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
                logger.info(f"Processing downloaded files from {dataset_id}")

                # Count imported files
                video_count = 0
                image_count = 0
                tar_count = 0

                # Process function for the event loop
                async def process_files():
                    nonlocal video_count, image_count, tar_count

                    # Process all files in the temp directory
                    for root, _, files in os.walk(temp_path):
                        for file in files:
                            file_path = Path(root) / file

                            # Process videos
                            if file.lower().endswith((".mp4", ".webm")):
                                # Choose target path based on auto-splitting setting
                                target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                                target_path = target_dir / file_path.name

                                # Make sure filename is unique
                                counter = 1
                                while target_path.exists():
                                    stem = Path(file_path.name).stem
                                    if "___" in stem:
                                        base_stem = stem.split("___")[0]
                                    else:
                                        base_stem = stem
                                    target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
                                    counter += 1

                                # Copy the video file
                                shutil.copy2(file_path, target_path)
                                logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")

                                # Copy associated caption file if it exists
                                txt_path = file_path.with_suffix('.txt')
                                if txt_path.exists():
                                    shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                    logger.info(f"Copied caption for {file_path.name}")

                                video_count += 1

                            # Process images
                            elif is_image_file(file_path):
                                target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"

                                counter = 1
                                while target_path.exists():
                                    target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                                    counter += 1

                                if normalize_image(file_path, target_path):
                                    logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")

                                    # Copy caption if available
                                    txt_path = file_path.with_suffix('.txt')
                                    if txt_path.exists():
                                        caption = txt_path.read_text()
                                        caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
                                        target_path.with_suffix('.txt').write_text(caption)
                                        logger.info(f"Processed caption for {file_path.name}")

                                    image_count += 1

                            # Process WebDataset files
                            elif file.lower().endswith(".tar"):
                                # Process the WebDataset archive
                                try:
                                    logger.info(f"Processing WebDataset file from dataset: {file}")
                                    vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                        file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                    )
                                    tar_count += 1
                                    video_count += vid_count
                                    image_count += img_count
                                    logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
                                except Exception as e:
                                    logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)

                # Run the processing asynchronously
                await process_files()

                # Generate final status message
                parts = []
                if video_count > 0:
                    parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                if image_count > 0:
                    parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
                if tar_count > 0:
                    parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")

                if parts:
                    status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
                    loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
                    logger.info(status)
                else:
                    status = f"No supported media files found in dataset {dataset_id}"
                    loading_msg = f"{loading_msg}\n\n⚠️ {status}"
                    logger.warning(status)

                gr.Info(status)
                return loading_msg, status

        except Exception as e:
            error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}", error_msg
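A minimal sketch of how HubDatasetBrowser might be used on its own, assuming a public dataset (the dataset ID below is a placeholder):

```python
import asyncio
from huggingface_hub import HfApi
from vms.services.importer import HubDatasetBrowser

browser = HubDatasetBrowser(HfApi())

# search_datasets returns [dataset_id, title, downloads] rows, most downloaded first.
for dataset_id, title, downloads in browser.search_datasets("video")[:5]:
    print(dataset_id, title, downloads)

# "some-user/some-video-dataset" is a placeholder dataset ID.
info_md, counts, groups = browser.get_dataset_info("some-user/some-video-dataset")
print(counts)  # e.g. {"video": 12, "webdataset": 0}

# download_dataset is async, so it needs an event loop when called outside Gradio.
loading_msg, status = asyncio.run(browser.download_dataset("some-user/some-video-dataset"))
print(status)
```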
vms/services/importer/import_service.py
ADDED
@@ -0,0 +1,102 @@
"""
Main Import Service for Video Model Studio.
Delegates to specialized handler classes for different import types.
"""

import logging
from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path
import gradio as gr

from huggingface_hub import HfApi

from .file_upload import FileUploadHandler
from .youtube import YouTubeDownloader
from .hub_dataset import HubDatasetBrowser
from vms.config import HF_API_TOKEN

logger = logging.getLogger(__name__)

class ImportService:
    """Main service class for handling imports from various sources"""

    def __init__(self):
        """Initialize the import service and handlers"""
        self.hf_api = HfApi(token=HF_API_TOKEN)
        self.file_handler = FileUploadHandler()
        self.youtube_handler = YouTubeDownloader()
        self.hub_browser = HubDatasetBrowser(self.hf_api)

    def process_uploaded_files(self, file_paths: List[str]) -> str:
        """Process uploaded file (ZIP, TAR, MP4, or image)

        Args:
            file_paths: File paths to the uploaded files from Gradio

        Returns:
            Status message string
        """
        if not file_paths or len(file_paths) == 0:
            logger.warning("No files provided to process_uploaded_files")
            return "No files provided"

        return self.file_handler.process_uploaded_files(file_paths)

    def download_youtube_video(self, url: str, progress=None) -> str:
        """Download a video from YouTube

        Args:
            url: YouTube video URL
            progress: Optional Gradio progress indicator

        Returns:
            Status message string
        """
        return self.youtube_handler.download_video(url, progress)

    def search_datasets(self, query: str) -> List[List[str]]:
        """Search for datasets on the Hugging Face Hub

        Args:
            query: Search query string

        Returns:
            List of datasets matching the query [id, title, downloads]
        """
        return self.hub_browser.search_datasets(query)

    def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
        """Get detailed information about a dataset

        Args:
            dataset_id: The dataset ID to get information for

        Returns:
            Tuple of (markdown_info, file_counts, file_groups)
        """
        return self.hub_browser.get_dataset_info(dataset_id)

    async def download_dataset(self, dataset_id: str, enable_splitting: bool = True) -> Tuple[str, str]:
        """Download a dataset and process its video/image content

        Args:
            dataset_id: The dataset ID to download
            enable_splitting: Whether to enable automatic video splitting

        Returns:
            Tuple of (loading_msg, status_msg)
        """
        return await self.hub_browser.download_dataset(dataset_id, enable_splitting)

    async def download_file_group(self, dataset_id: str, file_type: str, enable_splitting: bool = True) -> str:
        """Download a group of files (videos or WebDatasets)

        Args:
            dataset_id: The dataset ID
            file_type: Type of file ("video" or "webdataset")
            enable_splitting: Whether to enable automatic video splitting

        Returns:
            Status message
        """
        return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting)
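Since ImportService is a thin facade over the three handlers, a minimal usage sketch looks like this (the URL and dataset ID are placeholders):

```python
import asyncio
from vms.services import ImportService

service = ImportService()

# Synchronous delegation to YouTubeDownloader (placeholder URL).
print(service.download_youtube_video("https://www.youtube.com/watch?v=XXXXXXXXXXX"))

# Async delegation to HubDatasetBrowser.
results = service.search_datasets("video")
if results:
    loading_msg, status = asyncio.run(service.download_dataset(results[0][0]))
    print(status)
```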
vms/services/importer/youtube.py
ADDED
@@ -0,0 +1,73 @@
"""
YouTube downloader for Video Model Studio.
Handles downloading videos from YouTube URLs.
"""

import logging
import gradio as gr
from pathlib import Path
from typing import Optional, Any, Union, Callable

from pytubefix import YouTube

from vms.config import VIDEOS_TO_SPLIT_PATH

logger = logging.getLogger(__name__)

class YouTubeDownloader:
    """Handles downloading videos from YouTube"""

    def download_video(self, url: str, progress: Optional[Callable] = None) -> str:
        """Download a video from YouTube

        Args:
            url: YouTube video URL
            progress: Optional Gradio progress indicator

        Returns:
            Status message string
        """
        if not url or not url.strip():
            logger.warning("No YouTube URL provided")
            return "Please enter a YouTube URL"

        try:
            logger.info(f"Downloading YouTube video: {url}")

            # Extract video ID and create YouTube object
            yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining:
                        progress((1 - bytes_remaining / stream.filesize), desc="Downloading...")
                        if progress else None)

            video_id = yt.video_id
            output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"

            # Download highest quality progressive MP4
            if progress:
                logger.debug("Getting video streams...")
                progress(0, desc="Getting video streams...")
            video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()

            if not video:
                logger.error("Could not find a compatible video format")
                gr.Error("Could not find a compatible video format")
                return "Could not find a compatible video format"

            # Download the video
            if progress:
                logger.info("Starting YouTube video download...")
                progress(0, desc="Starting download...")

            video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")

            # Update UI
            if progress:
                logger.info("YouTube video download complete!")
                gr.Info("YouTube video download complete!")
                progress(1, desc="Download complete!")
            return f"Successfully downloaded video: {yt.title}"

        except Exception as e:
            logger.error(f"Error downloading YouTube video: {str(e)}", exc_info=True)
            gr.Error(f"Error downloading video: {str(e)}")
            return f"Error downloading video: {str(e)}"
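The progress argument is normally a gr.Progress instance; a stand-in callable with the same call shape is enough to try the downloader by hand (the URL is a placeholder):

```python
from vms.services.importer import YouTubeDownloader

def fake_progress(fraction, desc=""):
    # Stand-in for gr.Progress(): just print what the downloader reports.
    print(f"{desc} {fraction:.0%}")

downloader = YouTubeDownloader()
status = downloader.download_video("https://www.youtube.com/watch?v=XXXXXXXXXXX", progress=fake_progress)
print(status)
```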
vms/services/monitoring.py
ADDED
@@ -0,0 +1,361 @@
"""
System monitoring service for Video Model Studio.
Tracks system resources like CPU, memory, and other metrics.
"""

import os
import time
import logging
import platform
import threading
from datetime import datetime, timedelta
from collections import deque
from typing import Dict, List, Optional, Tuple, Any

import psutil

# Force the use of the Agg backend which is thread-safe
import matplotlib
matplotlib.use('Agg')  # Must be before importing pyplot
import matplotlib.pyplot as plt

import numpy as np

logger = logging.getLogger(__name__)

class MonitoringService:
    """Service for monitoring system resources and performance"""

    def __init__(self, history_minutes: int = 10, sample_interval: int = 5):
        """Initialize the monitoring service

        Args:
            history_minutes: How many minutes of history to keep
            sample_interval: How many seconds between samples
        """
        self.history_minutes = history_minutes
        self.sample_interval = sample_interval
        self.max_samples = (history_minutes * 60) // sample_interval

        # Initialize data structures for metrics
        self.timestamps = deque(maxlen=self.max_samples)
        self.cpu_percent = deque(maxlen=self.max_samples)
        self.memory_percent = deque(maxlen=self.max_samples)
        self.memory_used = deque(maxlen=self.max_samples)
        self.memory_available = deque(maxlen=self.max_samples)

        # CPU temperature history (might not be available on all systems)
        self.cpu_temp = deque(maxlen=self.max_samples)

        # Per-core CPU history
        self.cpu_cores_percent = {}

        # Track if the monitoring thread is running
        self.is_running = False
        self.thread = None

        # Initialize with current values
        self.collect_metrics()

    def collect_metrics(self) -> Dict[str, Any]:
        """Collect current system metrics

        Returns:
            Dictionary of current metrics
        """
        metrics = {
            'timestamp': datetime.now(),
            'cpu_percent': psutil.cpu_percent(interval=0.1),
            'memory_percent': psutil.virtual_memory().percent,
            'memory_used': psutil.virtual_memory().used / (1024**3),  # GB
            'memory_available': psutil.virtual_memory().available / (1024**3),  # GB
            'cpu_temp': None,
            'per_cpu_percent': psutil.cpu_percent(interval=0.1, percpu=True)
        }

        # Try to get CPU temperature (platform specific)
        try:
            if platform.system() == 'Linux':
                # Try to get temperature from psutil
                temps = psutil.sensors_temperatures()
                for name, entries in temps.items():
                    if name.startswith(('coretemp', 'k10temp', 'cpu_thermal')):
                        metrics['cpu_temp'] = entries[0].current
                        break
            elif platform.system() == 'Darwin':  # macOS
                # On macOS, we could use SMC reader but it requires additional dependencies
                # Leaving as None for now
                pass
            elif platform.system() == 'Windows':
                # Windows might require WMI, leaving as None for simplicity
                pass
        except (AttributeError, KeyError, IndexError, NotImplementedError):
            # Sensors not available
            pass

        return metrics

    def update_history(self, metrics: Dict[str, Any]) -> None:
        """Update metric history with new values

        Args:
            metrics: New metrics to add to history
        """
        self.timestamps.append(metrics['timestamp'])
        self.cpu_percent.append(metrics['cpu_percent'])
        self.memory_percent.append(metrics['memory_percent'])
        self.memory_used.append(metrics['memory_used'])
        self.memory_available.append(metrics['memory_available'])

        if metrics['cpu_temp'] is not None:
            self.cpu_temp.append(metrics['cpu_temp'])

        # Update per-core CPU metrics
        for i, percent in enumerate(metrics['per_cpu_percent']):
            if i not in self.cpu_cores_percent:
                self.cpu_cores_percent[i] = deque(maxlen=self.max_samples)
            self.cpu_cores_percent[i].append(percent)

    def start_monitoring(self) -> None:
"""Start background thread for collecting metrics"""
|
121 |
+
if self.is_running:
|
122 |
+
logger.warning("Monitoring thread already running")
|
123 |
+
return
|
124 |
+
|
125 |
+
self.is_running = True
|
126 |
+
|
127 |
+
def _monitor_loop():
|
128 |
+
while self.is_running:
|
129 |
+
try:
|
130 |
+
metrics = self.collect_metrics()
|
131 |
+
self.update_history(metrics)
|
132 |
+
time.sleep(self.sample_interval)
|
133 |
+
except Exception as e:
|
134 |
+
logger.error(f"Error in monitoring thread: {str(e)}", exc_info=True)
|
135 |
+
time.sleep(self.sample_interval)
|
136 |
+
|
137 |
+
self.thread = threading.Thread(target=_monitor_loop, daemon=True)
|
138 |
+
self.thread.start()
|
139 |
+
logger.info("System monitoring thread started")
|
140 |
+
|
141 |
+
def stop_monitoring(self) -> None:
|
142 |
+
"""Stop the monitoring thread"""
|
143 |
+
if not self.is_running:
|
144 |
+
return
|
145 |
+
|
146 |
+
self.is_running = False
|
147 |
+
if self.thread:
|
148 |
+
self.thread.join(timeout=1.0)
|
149 |
+
logger.info("System monitoring thread stopped")
|
150 |
+
|
151 |
+
def get_current_metrics(self) -> Dict[str, Any]:
|
152 |
+
"""Get current system metrics
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
Dictionary with current system metrics
|
156 |
+
"""
|
157 |
+
return self.collect_metrics()
|
158 |
+
|
159 |
+
def get_system_info(self) -> Dict[str, Any]:
|
160 |
+
"""Get general system information
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
Dictionary with system details
|
164 |
+
"""
|
165 |
+
cpu_info = {
|
166 |
+
'cores_physical': psutil.cpu_count(logical=False),
|
167 |
+
'cores_logical': psutil.cpu_count(logical=True),
|
168 |
+
'current_frequency': None,
|
169 |
+
'architecture': platform.machine(),
|
170 |
+
}
|
171 |
+
|
172 |
+
# Try to get CPU frequency
|
173 |
+
try:
|
174 |
+
cpu_freq = psutil.cpu_freq()
|
175 |
+
if cpu_freq:
|
176 |
+
cpu_info['current_frequency'] = cpu_freq.current
|
177 |
+
except Exception:
|
178 |
+
pass
|
179 |
+
|
180 |
+
memory_info = {
|
181 |
+
'total': psutil.virtual_memory().total / (1024**3), # GB
|
182 |
+
'available': psutil.virtual_memory().available / (1024**3), # GB
|
183 |
+
'used': psutil.virtual_memory().used / (1024**3), # GB
|
184 |
+
'percent': psutil.virtual_memory().percent
|
185 |
+
}
|
186 |
+
|
187 |
+
disk_info = {}
|
188 |
+
for part in psutil.disk_partitions(all=False):
|
189 |
+
if os.name == 'nt' and ('cdrom' in part.opts or part.fstype == ''):
|
190 |
+
# Skip CD-ROM drives on Windows
|
191 |
+
continue
|
192 |
+
try:
|
193 |
+
usage = psutil.disk_usage(part.mountpoint)
|
194 |
+
disk_info[part.mountpoint] = {
|
195 |
+
'total': usage.total / (1024**3), # GB
|
196 |
+
'used': usage.used / (1024**3), # GB
|
197 |
+
'free': usage.free / (1024**3), # GB
|
198 |
+
'percent': usage.percent
|
199 |
+
}
|
200 |
+
except PermissionError:
|
201 |
+
continue
|
202 |
+
|
203 |
+
sys_info = {
|
204 |
+
'system': platform.system(),
|
205 |
+
'version': platform.version(),
|
206 |
+
'platform': platform.platform(),
|
207 |
+
'processor': platform.processor(),
|
208 |
+
'hostname': platform.node(),
|
209 |
+
'python_version': platform.python_version(),
|
210 |
+
'uptime': time.time() - psutil.boot_time()
|
211 |
+
}
|
212 |
+
|
213 |
+
return {
|
214 |
+
'cpu': cpu_info,
|
215 |
+
'memory': memory_info,
|
216 |
+
'disk': disk_info,
|
217 |
+
'system': sys_info,
|
218 |
+
}
|
219 |
+
|
220 |
+
def generate_cpu_plot(self) -> plt.Figure:
|
221 |
+
"""Generate a plot of CPU usage over time
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
Matplotlib figure with CPU usage plot
|
225 |
+
"""
|
226 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
227 |
+
|
228 |
+
if not self.timestamps:
|
229 |
+
ax.set_title("No CPU data available yet")
|
230 |
+
return fig
|
231 |
+
|
232 |
+
x = [t.strftime('%H:%M:%S') for t in self.timestamps]
|
233 |
+
if len(x) > 10:
|
234 |
+
# Show fewer x-axis labels for readability
|
235 |
+
step = len(x) // 10
|
236 |
+
ax.set_xticks(range(0, len(x), step))
|
237 |
+
ax.set_xticklabels([x[i] for i in range(0, len(x), step)])
|
238 |
+
|
239 |
+
ax.plot(x, list(self.cpu_percent), 'b-', label='CPU Usage %')
|
240 |
+
|
241 |
+
if self.cpu_temp and len(self.cpu_temp) > 0:
|
242 |
+
# Plot temperature on a secondary y-axis if available
|
243 |
+
ax2 = ax.twinx()
|
244 |
+
ax2.plot(x[:len(self.cpu_temp)], list(self.cpu_temp), 'r-', label='CPU Temp °C')
|
245 |
+
ax2.set_ylabel('Temperature (°C)', color='r')
|
246 |
+
ax2.tick_params(axis='y', colors='r')
|
247 |
+
|
248 |
+
ax.set_title('CPU Usage Over Time')
|
249 |
+
ax.set_xlabel('Time')
|
250 |
+
ax.set_ylabel('Usage %')
|
251 |
+
ax.grid(True, alpha=0.3)
|
252 |
+
ax.set_ylim(0, 100)
|
253 |
+
|
254 |
+
# Add legend
|
255 |
+
lines, labels = ax.get_legend_handles_labels()
|
256 |
+
if hasattr(locals(), 'ax2'):
|
257 |
+
lines2, labels2 = ax2.get_legend_handles_labels()
|
258 |
+
ax.legend(lines + lines2, labels + labels2, loc='upper left')
|
259 |
+
else:
|
260 |
+
ax.legend(loc='upper left')
|
261 |
+
|
262 |
+
plt.tight_layout()
|
263 |
+
return fig
|
264 |
+
|
265 |
+
def generate_memory_plot(self) -> plt.Figure:
|
266 |
+
"""Generate a plot of memory usage over time
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
Matplotlib figure with memory usage plot
|
270 |
+
"""
|
271 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
272 |
+
|
273 |
+
if not self.timestamps:
|
274 |
+
ax.set_title("No memory data available yet")
|
275 |
+
return fig
|
276 |
+
|
277 |
+
x = [t.strftime('%H:%M:%S') for t in self.timestamps]
|
278 |
+
if len(x) > 10:
|
279 |
+
# Show fewer x-axis labels for readability
|
280 |
+
step = len(x) // 10
|
281 |
+
ax.set_xticks(range(0, len(x), step))
|
282 |
+
ax.set_xticklabels([x[i] for i in range(0, len(x), step)])
|
283 |
+
|
284 |
+
ax.plot(x, list(self.memory_percent), 'g-', label='Memory Usage %')
|
285 |
+
|
286 |
+
# Add secondary y-axis for absolute memory values
|
287 |
+
ax2 = ax.twinx()
|
288 |
+
ax2.plot(x, list(self.memory_used), 'm--', label='Used (GB)')
|
289 |
+
ax2.plot(x, list(self.memory_available), 'c--', label='Available (GB)')
|
290 |
+
ax2.set_ylabel('Memory (GB)')
|
291 |
+
|
292 |
+
ax.set_title('Memory Usage Over Time')
|
293 |
+
ax.set_xlabel('Time')
|
294 |
+
ax.set_ylabel('Usage %')
|
295 |
+
ax.grid(True, alpha=0.3)
|
296 |
+
ax.set_ylim(0, 100)
|
297 |
+
|
298 |
+
# Add legend
|
299 |
+
lines, labels = ax.get_legend_handles_labels()
|
300 |
+
lines2, labels2 = ax2.get_legend_handles_labels()
|
301 |
+
ax.legend(lines + lines2, labels + labels2, loc='upper left')
|
302 |
+
|
303 |
+
plt.tight_layout()
|
304 |
+
return fig
|
305 |
+
|
306 |
+
def generate_per_core_plot(self) -> plt.Figure:
|
307 |
+
"""Generate a plot of per-core CPU usage
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
Matplotlib figure with per-core CPU usage
|
311 |
+
"""
|
312 |
+
num_cores = len(self.cpu_cores_percent)
|
313 |
+
if num_cores == 0:
|
314 |
+
# No data yet
|
315 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
316 |
+
ax.set_title("No per-core CPU data available yet")
|
317 |
+
return fig
|
318 |
+
|
319 |
+
# Determine grid layout based on number of cores
|
320 |
+
if num_cores <= 4:
|
321 |
+
rows, cols = 2, 2
|
322 |
+
elif num_cores <= 6:
|
323 |
+
rows, cols = 2, 3
|
324 |
+
elif num_cores <= 9:
|
325 |
+
rows, cols = 3, 3
|
326 |
+
elif num_cores <= 12:
|
327 |
+
rows, cols = 3, 4
|
328 |
+
else:
|
329 |
+
rows, cols = 4, 4
|
330 |
+
|
331 |
+
fig, axes = plt.subplots(rows, cols, figsize=(12, 8), sharex=True, sharey=True)
|
332 |
+
axes = axes.flatten()
|
333 |
+
|
334 |
+
x = [t.strftime('%H:%M:%S') for t in self.timestamps]
|
335 |
+
if len(x) > 5:
|
336 |
+
# Show fewer x-axis labels for readability
|
337 |
+
step = len(x) // 5
|
338 |
+
else:
|
339 |
+
step = 1
|
340 |
+
|
341 |
+
for i, (core_id, percentages) in enumerate(self.cpu_cores_percent.items()):
|
342 |
+
if i >= len(axes):
|
343 |
+
break
|
344 |
+
|
345 |
+
ax = axes[i]
|
346 |
+
ax.plot(x[:len(percentages)], list(percentages), 'b-')
|
347 |
+
ax.set_title(f'Core {core_id}')
|
348 |
+
ax.set_ylim(0, 100)
|
349 |
+
ax.grid(True, alpha=0.3)
|
350 |
+
|
351 |
+
# Add x-axis labels sparingly for readability
|
352 |
+
if i >= len(axes) - cols: # Only for bottom row
|
353 |
+
ax.set_xticks(range(0, len(x), step))
|
354 |
+
ax.set_xticklabels([x[i] for i in range(0, len(x), step)], rotation=45)
|
355 |
+
|
356 |
+
# Hide unused subplots
|
357 |
+
for i in range(num_cores, len(axes)):
|
358 |
+
axes[i].set_visible(False)
|
359 |
+
|
360 |
+
plt.tight_layout()
|
361 |
+
return fig
|
vms/tabs/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .import_tab import ImportTab
|
|
6 |
from .split_tab import SplitTab
|
7 |
from .caption_tab import CaptionTab
|
8 |
from .train_tab import TrainTab
|
|
|
9 |
from .manage_tab import ManageTab
|
10 |
|
11 |
__all__ = [
|
@@ -13,5 +14,6 @@ __all__ = [
|
|
13 |
'SplitTab',
|
14 |
'CaptionTab',
|
15 |
'TrainTab',
|
|
|
16 |
'ManageTab'
|
17 |
]
|
|
|
6 |
from .split_tab import SplitTab
|
7 |
from .caption_tab import CaptionTab
|
8 |
from .train_tab import TrainTab
|
9 |
+
from .monitor_tab import MonitorTab
|
10 |
from .manage_tab import ManageTab
|
11 |
|
12 |
__all__ = [
|
|
|
14 |
'SplitTab',
|
15 |
'CaptionTab',
|
16 |
'TrainTab',
|
17 |
+
'MonitorTab',
|
18 |
'ManageTab'
|
19 |
]
|
vms/tabs/import_tab/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Import tab for Video Model Studio.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .upload_tab import UploadTab
|
6 |
+
from .youtube_tab import YouTubeTab
|
7 |
+
from .hub_tab import HubTab
|
8 |
+
from .import_tab import ImportTab
|
9 |
+
|
10 |
+
__all__ = ['UploadTab', 'YouTubeTab', 'HubTab', 'ImportTab']
|
vms/tabs/import_tab/hub_tab.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Hugging Face Hub tab for Video Model Studio UI.
|
3 |
+
Handles browsing, searching, and importing datasets from the Hugging Face Hub.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import logging
|
8 |
+
import asyncio
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Dict, Any, List, Optional, Tuple
|
11 |
+
|
12 |
+
from ..base_tab import BaseTab
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
class HubTab(BaseTab):
|
17 |
+
"""Hub tab for importing datasets from Hugging Face Hub"""
|
18 |
+
|
19 |
+
def __init__(self, app_state):
|
20 |
+
super().__init__(app_state)
|
21 |
+
self.id = "hub_tab"
|
22 |
+
self.title = "Import from Hugging Face"
|
23 |
+
|
24 |
+
def create(self, parent=None) -> gr.Tab:
|
25 |
+
"""Create the Hub tab UI components"""
|
26 |
+
with gr.Tab(self.title, id=self.id) as tab:
|
27 |
+
with gr.Column():
|
28 |
+
with gr.Row():
|
29 |
+
gr.Markdown("## Import from Hub datasets")
|
30 |
+
|
31 |
+
with gr.Row():
|
32 |
+
gr.Markdown("Search for datasets with videos or WebDataset archives:")
|
33 |
+
|
34 |
+
with gr.Row():
|
35 |
+
self.components["dataset_search"] = gr.Textbox(
|
36 |
+
label="Search Hugging Face Datasets",
|
37 |
+
placeholder="Search for video datasets..."
|
38 |
+
)
|
39 |
+
|
40 |
+
with gr.Row():
|
41 |
+
self.components["dataset_search_btn"] = gr.Button("Search Datasets", variant="primary")
|
42 |
+
|
43 |
+
# Dataset browser results section
|
44 |
+
with gr.Row(visible=False) as dataset_results_row:
|
45 |
+
self.components["dataset_results_row"] = dataset_results_row
|
46 |
+
|
47 |
+
with gr.Column(scale=3):
|
48 |
+
self.components["dataset_results"] = gr.Dataframe(
|
49 |
+
headers=["id", "title", "downloads"],
|
50 |
+
interactive=False,
|
51 |
+
wrap=True,
|
52 |
+
row_count=10,
|
53 |
+
label="Dataset Results"
|
54 |
+
)
|
55 |
+
|
56 |
+
with gr.Column(scale=3):
|
57 |
+
# Dataset info and state
|
58 |
+
self.components["dataset_info"] = gr.Markdown("Select a dataset to see details")
|
59 |
+
self.components["dataset_id"] = gr.State(value=None)
|
60 |
+
self.components["file_type"] = gr.State(value=None)
|
61 |
+
|
62 |
+
# Files section that appears when a dataset is selected
|
63 |
+
with gr.Column(visible=False) as files_section:
|
64 |
+
self.components["files_section"] = files_section
|
65 |
+
|
66 |
+
gr.Markdown("## Files:")
|
67 |
+
|
68 |
+
# Video files row (appears if videos are present)
|
69 |
+
with gr.Row(visible=False) as video_files_row:
|
70 |
+
self.components["video_files_row"] = video_files_row
|
71 |
+
|
72 |
+
with gr.Column(scale=4):
|
73 |
+
self.components["video_count_text"] = gr.Markdown("Contains 0 video files")
|
74 |
+
|
75 |
+
with gr.Column(scale=1):
|
76 |
+
self.components["download_videos_btn"] = gr.Button("Download", variant="primary")
|
77 |
+
|
78 |
+
# WebDataset files row (appears if tar files are present)
|
79 |
+
with gr.Row(visible=False) as webdataset_files_row:
|
80 |
+
self.components["webdataset_files_row"] = webdataset_files_row
|
81 |
+
|
82 |
+
with gr.Column(scale=4):
|
83 |
+
self.components["webdataset_count_text"] = gr.Markdown("Contains 0 WebDataset (.tar) files")
|
84 |
+
|
85 |
+
with gr.Column(scale=1):
|
86 |
+
self.components["download_webdataset_btn"] = gr.Button("Download", variant="primary")
|
87 |
+
|
88 |
+
# Status and loading indicators
|
89 |
+
self.components["dataset_loading"] = gr.Markdown(visible=False)
|
90 |
+
|
91 |
+
return tab
|
92 |
+
|
93 |
+
def connect_events(self) -> None:
|
94 |
+
"""Connect event handlers to UI components"""
|
95 |
+
# Dataset search event
|
96 |
+
self.components["dataset_search_btn"].click(
|
97 |
+
fn=self.search_datasets,
|
98 |
+
inputs=[self.components["dataset_search"]],
|
99 |
+
outputs=[
|
100 |
+
self.components["dataset_results"],
|
101 |
+
self.components["dataset_results_row"]
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
# Dataset selection event - FIX HERE
|
106 |
+
self.components["dataset_results"].select(
|
107 |
+
fn=self.display_dataset_info,
|
108 |
+
outputs=[
|
109 |
+
self.components["dataset_info"],
|
110 |
+
self.components["dataset_id"],
|
111 |
+
self.components["files_section"],
|
112 |
+
self.components["video_files_row"],
|
113 |
+
self.components["video_count_text"],
|
114 |
+
self.components["webdataset_files_row"],
|
115 |
+
self.components["webdataset_count_text"]
|
116 |
+
]
|
117 |
+
)
|
118 |
+
|
119 |
+
# Download videos button
|
120 |
+
self.components["download_videos_btn"].click(
|
121 |
+
fn=self.set_file_type_and_return,
|
122 |
+
outputs=[self.components["file_type"]]
|
123 |
+
).then(
|
124 |
+
fn=self.download_file_group,
|
125 |
+
inputs=[
|
126 |
+
self.components["dataset_id"],
|
127 |
+
self.components["enable_automatic_video_split"],
|
128 |
+
self.components["file_type"]
|
129 |
+
],
|
130 |
+
outputs=[
|
131 |
+
self.components["dataset_loading"],
|
132 |
+
self.components["import_status"]
|
133 |
+
]
|
134 |
+
).success(
|
135 |
+
fn=self.app.tabs["import_tab"].on_import_success,
|
136 |
+
inputs=[
|
137 |
+
self.components["enable_automatic_video_split"],
|
138 |
+
self.components["enable_automatic_content_captioning"],
|
139 |
+
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
140 |
+
],
|
141 |
+
outputs=[
|
142 |
+
self.app.tabs_component,
|
143 |
+
self.app.tabs["split_tab"].components["video_list"],
|
144 |
+
self.app.tabs["split_tab"].components["detect_status"]
|
145 |
+
]
|
146 |
+
)
|
147 |
+
|
148 |
+
# Download WebDataset button
|
149 |
+
self.components["download_webdataset_btn"].click(
|
150 |
+
fn=self.set_file_type_and_return_webdataset,
|
151 |
+
outputs=[self.components["file_type"]]
|
152 |
+
).then(
|
153 |
+
fn=self.download_file_group,
|
154 |
+
inputs=[
|
155 |
+
self.components["dataset_id"],
|
156 |
+
self.components["enable_automatic_video_split"],
|
157 |
+
self.components["file_type"]
|
158 |
+
],
|
159 |
+
outputs=[
|
160 |
+
self.components["dataset_loading"],
|
161 |
+
self.components["import_status"]
|
162 |
+
]
|
163 |
+
).success(
|
164 |
+
fn=self.app.tabs["import_tab"].on_import_success,
|
165 |
+
inputs=[
|
166 |
+
self.components["enable_automatic_video_split"],
|
167 |
+
self.components["enable_automatic_content_captioning"],
|
168 |
+
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
169 |
+
],
|
170 |
+
outputs=[
|
171 |
+
self.app.tabs_component,
|
172 |
+
self.app.tabs["split_tab"].components["video_list"],
|
173 |
+
self.app.tabs["split_tab"].components["detect_status"]
|
174 |
+
]
|
175 |
+
)
|
176 |
+
|
177 |
+
def set_file_type_and_return(self):
|
178 |
+
"""Set file type to video and return it"""
|
179 |
+
return "video"
|
180 |
+
|
181 |
+
def set_file_type_and_return_webdataset(self):
|
182 |
+
"""Set file type to webdataset and return it"""
|
183 |
+
return "webdataset"
|
184 |
+
|
185 |
+
def search_datasets(self, query: str):
|
186 |
+
"""Search datasets on the Hub matching the query"""
|
187 |
+
try:
|
188 |
+
logger.info(f"Searching for datasets with query: '{query}'")
|
189 |
+
results = self.app.importer.search_datasets(query)
|
190 |
+
return results, gr.update(visible=True)
|
191 |
+
except Exception as e:
|
192 |
+
logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
|
193 |
+
return [[f"Error: {str(e)}", "", ""]], gr.update(visible=True)
|
194 |
+
|
195 |
+
def display_dataset_info(self, evt: gr.SelectData):
|
196 |
+
"""Display detailed information about the selected dataset"""
|
197 |
+
try:
|
198 |
+
if not evt or not evt.value:
|
199 |
+
logger.warning("No dataset selected in display_dataset_info")
|
200 |
+
return (
|
201 |
+
"No dataset selected", # dataset_info
|
202 |
+
None, # dataset_id
|
203 |
+
gr.update(visible=False), # files_section
|
204 |
+
gr.update(visible=False), # video_files_row
|
205 |
+
"", # video_count_text
|
206 |
+
gr.update(visible=False), # webdataset_files_row
|
207 |
+
"" # webdataset_count_text
|
208 |
+
)
|
209 |
+
|
210 |
+
dataset_id = evt.value[0] if isinstance(evt.value, list) else evt.value
|
211 |
+
logger.info(f"Getting dataset info for: {dataset_id}")
|
212 |
+
|
213 |
+
# Use the importer service to get dataset info
|
214 |
+
info_text, file_counts, _ = self.app.importer.get_dataset_info(dataset_id)
|
215 |
+
|
216 |
+
# Get counts of each file type
|
217 |
+
video_count = file_counts.get("video", 0)
|
218 |
+
webdataset_count = file_counts.get("webdataset", 0)
|
219 |
+
|
220 |
+
# Return all the required outputs individually
|
221 |
+
return (
|
222 |
+
info_text, # dataset_info
|
223 |
+
dataset_id, # dataset_id
|
224 |
+
gr.update(visible=True), # files_section
|
225 |
+
gr.update(visible=video_count > 0), # video_files_row
|
226 |
+
f"Contains {video_count} video file{'s' if video_count != 1 else ''}", # video_count_text
|
227 |
+
gr.update(visible=webdataset_count > 0), # webdataset_files_row
|
228 |
+
f"Contains {webdataset_count} WebDataset (.tar) file{'s' if webdataset_count != 1 else ''}" # webdataset_count_text
|
229 |
+
)
|
230 |
+
except Exception as e:
|
231 |
+
logger.error(f"Error displaying dataset info: {str(e)}", exc_info=True)
|
232 |
+
return (
|
233 |
+
f"Error loading dataset information: {str(e)}", # dataset_info
|
234 |
+
None, # dataset_id
|
235 |
+
gr.update(visible=False), # files_section
|
236 |
+
gr.update(visible=False), # video_files_row
|
237 |
+
"", # video_count_text
|
238 |
+
gr.update(visible=False), # webdataset_files_row
|
239 |
+
"" # webdataset_count_text
|
240 |
+
)
|
241 |
+
|
242 |
+
def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str) -> Tuple[gr.update, str]:
|
243 |
+
"""Handle download of a group of files (videos or WebDatasets)"""
|
244 |
+
try:
|
245 |
+
if not dataset_id:
|
246 |
+
return gr.update(visible=False), "No dataset selected"
|
247 |
+
|
248 |
+
logger.info(f"Starting download of {file_type} files from dataset: {dataset_id}")
|
249 |
+
|
250 |
+
# Show loading indicator
|
251 |
+
loading_msg = gr.update(
|
252 |
+
value=f"## Downloading {file_type} files from {dataset_id}\n\nThis may take some time...",
|
253 |
+
visible=True
|
254 |
+
)
|
255 |
+
status_msg = f"Downloading {file_type} files from {dataset_id}..."
|
256 |
+
|
257 |
+
# Use the async version in a non-blocking way
|
258 |
+
asyncio.create_task(self._download_file_group_bg(dataset_id, file_type, enable_splitting))
|
259 |
+
|
260 |
+
return loading_msg, status_msg
|
261 |
+
|
262 |
+
except Exception as e:
|
263 |
+
error_msg = f"Error initiating download: {str(e)}"
|
264 |
+
logger.error(error_msg, exc_info=True)
|
265 |
+
return gr.update(visible=False), error_msg
|
266 |
+
|
267 |
+
async def _download_file_group_bg(self, dataset_id: str, file_type: str, enable_splitting: bool):
|
268 |
+
"""Background task for group file download"""
|
269 |
+
try:
|
270 |
+
# This will execute in the background
|
271 |
+
await self.app.importer.download_file_group(dataset_id, file_type, enable_splitting)
|
272 |
+
except Exception as e:
|
273 |
+
logger.error(f"Error in background file group download: {str(e)}", exc_info=True)
|
vms/tabs/{import_tab.py → import_tab/import_tab.py}
RENAMED
@@ -1,33 +1,42 @@
|
|
1 |
"""
|
2 |
-
|
3 |
"""
|
4 |
|
5 |
import gradio as gr
|
6 |
import logging
|
7 |
import asyncio
|
8 |
from pathlib import Path
|
9 |
-
from typing import Dict, Any, List, Optional
|
10 |
|
11 |
-
from
|
12 |
-
from
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
)
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
class ImportTab(BaseTab):
|
19 |
-
"""Import tab for uploading videos and images"""
|
20 |
|
21 |
def __init__(self, app_state):
|
22 |
super().__init__(app_state)
|
23 |
self.id = "import_tab"
|
24 |
self.title = "1️⃣ Import"
|
|
|
|
|
|
|
|
|
25 |
|
26 |
def create(self, parent=None) -> gr.TabItem:
|
27 |
-
"""Create the Import tab UI components"""
|
28 |
with gr.TabItem(self.title, id=self.id) as tab:
|
29 |
with gr.Row():
|
30 |
-
gr.Markdown("##
|
31 |
|
32 |
with gr.Row():
|
33 |
self.components["enable_automatic_video_split"] = gr.Checkbox(
|
@@ -42,38 +51,19 @@ class ImportTab(BaseTab):
|
|
42 |
value=False,
|
43 |
visible=True,
|
44 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
gr.Markdown("You can upload either:")
|
52 |
-
gr.Markdown("- A single MP4 video file")
|
53 |
-
gr.Markdown("- A ZIP archive containing multiple videos/images and optional caption files")
|
54 |
-
gr.Markdown("- A WebDataset shard (.tar file)")
|
55 |
-
gr.Markdown("- A ZIP archive containing WebDataset shards (.tar files)")
|
56 |
-
|
57 |
-
with gr.Row():
|
58 |
-
self.components["files"] = gr.Files(
|
59 |
-
label="Upload Images, Videos, ZIP or WebDataset",
|
60 |
-
file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip", ".tar"],
|
61 |
-
type="filepath"
|
62 |
-
)
|
63 |
-
|
64 |
-
with gr.Column(scale=3):
|
65 |
-
with gr.Row():
|
66 |
-
with gr.Column():
|
67 |
-
gr.Markdown("## Import a YouTube video")
|
68 |
-
gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
|
69 |
-
|
70 |
-
with gr.Row():
|
71 |
-
self.components["youtube_url"] = gr.Textbox(
|
72 |
-
label="Import YouTube Video",
|
73 |
-
placeholder="https://www.youtube.com/watch?v=..."
|
74 |
-
)
|
75 |
-
with gr.Row():
|
76 |
-
self.components["youtube_download_btn"] = gr.Button("Download YouTube Video", variant="secondary")
|
77 |
with gr.Row():
|
78 |
self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
|
79 |
|
@@ -81,47 +71,17 @@ class ImportTab(BaseTab):
|
|
81 |
|
82 |
def connect_events(self) -> None:
|
83 |
"""Connect event handlers to UI components"""
|
84 |
-
#
|
85 |
-
self.
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
],
|
96 |
-
outputs=[
|
97 |
-
self.app.tabs_component, # Main tabs component
|
98 |
-
self.app.tabs["split_tab"].components["video_list"],
|
99 |
-
self.app.tabs["split_tab"].components["detect_status"],
|
100 |
-
self.app.tabs["split_tab"].components["split_title"],
|
101 |
-
self.app.tabs["caption_tab"].components["caption_title"],
|
102 |
-
self.app.tabs["train_tab"].components["train_title"]
|
103 |
-
]
|
104 |
-
)
|
105 |
-
|
106 |
-
# YouTube download event
|
107 |
-
self.components["youtube_download_btn"].click(
|
108 |
-
fn=self.app.importer.download_youtube_video,
|
109 |
-
inputs=[self.components["youtube_url"]],
|
110 |
-
outputs=[self.components["import_status"]]
|
111 |
-
).success(
|
112 |
-
fn=self.on_import_success,
|
113 |
-
inputs=[
|
114 |
-
self.components["enable_automatic_video_split"],
|
115 |
-
self.components["enable_automatic_content_captioning"],
|
116 |
-
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
117 |
-
],
|
118 |
-
outputs=[
|
119 |
-
self.app.tabs_component,
|
120 |
-
self.app.tabs["split_tab"].components["video_list"],
|
121 |
-
self.app.tabs["split_tab"].components["detect_status"]
|
122 |
-
]
|
123 |
-
)
|
124 |
-
|
125 |
async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
126 |
"""Handle successful import of files"""
|
127 |
videos = self.app.tabs["split_tab"].list_unprocessed_videos()
|
|
|
1 |
"""
|
2 |
+
Parent import tab for Video Model Studio UI that contains sub-tabs
|
3 |
"""
|
4 |
|
5 |
import gradio as gr
|
6 |
import logging
|
7 |
import asyncio
|
8 |
from pathlib import Path
|
9 |
+
from typing import Dict, Any, List, Optional, Tuple
|
10 |
|
11 |
+
from ..base_tab import BaseTab
|
12 |
+
from .upload_tab import UploadTab
|
13 |
+
from .youtube_tab import YouTubeTab
|
14 |
+
from .hub_tab import HubTab
|
15 |
+
|
16 |
+
from vms.config import (
|
17 |
+
VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
|
18 |
+
STAGING_PATH
|
19 |
)
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
22 |
|
23 |
class ImportTab(BaseTab):
|
24 |
+
"""Import tab for uploading videos and images, and browsing datasets"""
|
25 |
|
26 |
def __init__(self, app_state):
|
27 |
super().__init__(app_state)
|
28 |
self.id = "import_tab"
|
29 |
self.title = "1️⃣ Import"
|
30 |
+
# Initialize sub-tabs
|
31 |
+
self.upload_tab = UploadTab(app_state)
|
32 |
+
self.youtube_tab = YouTubeTab(app_state)
|
33 |
+
self.hub_tab = HubTab(app_state)
|
34 |
|
35 |
def create(self, parent=None) -> gr.TabItem:
|
36 |
+
"""Create the Import tab UI components with three sub-tabs"""
|
37 |
with gr.TabItem(self.title, id=self.id) as tab:
|
38 |
with gr.Row():
|
39 |
+
gr.Markdown("## Import settings")
|
40 |
|
41 |
with gr.Row():
|
42 |
self.components["enable_automatic_video_split"] = gr.Checkbox(
|
|
|
51 |
value=False,
|
52 |
visible=True,
|
53 |
)
|
54 |
+
|
55 |
+
# Create tabs for different import methods
|
56 |
+
with gr.Tabs() as import_tabs:
|
57 |
+
# Create each sub-tab
|
58 |
+
self.upload_tab.create(import_tabs)
|
59 |
+
self.youtube_tab.create(import_tabs)
|
60 |
+
self.hub_tab.create(import_tabs)
|
61 |
|
62 |
+
# Store references to sub-tabs
|
63 |
+
self.components["upload_tab"] = self.upload_tab
|
64 |
+
self.components["youtube_tab"] = self.youtube_tab
|
65 |
+
self.components["hub_tab"] = self.hub_tab
|
66 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
with gr.Row():
|
68 |
self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
|
69 |
|
|
|
71 |
|
72 |
def connect_events(self) -> None:
|
73 |
"""Connect event handlers to UI components"""
|
74 |
+
# Set shared components from parent tab to sub-tabs first
|
75 |
+
for subtab in [self.upload_tab, self.youtube_tab, self.hub_tab]:
|
76 |
+
subtab.components["import_status"] = self.components["import_status"]
|
77 |
+
subtab.components["enable_automatic_video_split"] = self.components["enable_automatic_video_split"]
|
78 |
+
subtab.components["enable_automatic_content_captioning"] = self.components["enable_automatic_content_captioning"]
|
79 |
+
|
80 |
+
# Then connect events for each sub-tab
|
81 |
+
self.upload_tab.connect_events()
|
82 |
+
self.youtube_tab.connect_events()
|
83 |
+
self.hub_tab.connect_events()
|
84 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
86 |
"""Handle successful import of files"""
|
87 |
videos = self.app.tabs["split_tab"].list_unprocessed_videos()
|
vms/tabs/import_tab/upload_tab.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Upload tab for Video Model Studio UI.
|
3 |
+
Handles manual file uploads for videos, images, and archives.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Dict, Any, Optional
|
10 |
+
|
11 |
+
from ..base_tab import BaseTab
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
class UploadTab(BaseTab):
|
16 |
+
"""Upload tab for manual file uploads"""
|
17 |
+
|
18 |
+
def __init__(self, app_state):
|
19 |
+
super().__init__(app_state)
|
20 |
+
self.id = "upload_tab"
|
21 |
+
self.title = "Manual Upload"
|
22 |
+
|
23 |
+
def create(self, parent=None) -> gr.Tab:
|
24 |
+
"""Create the Upload tab UI components"""
|
25 |
+
with gr.Tab(self.title, id=self.id) as tab:
|
26 |
+
with gr.Column():
|
27 |
+
with gr.Row():
|
28 |
+
gr.Markdown("## Manual upload of video files")
|
29 |
+
|
30 |
+
with gr.Row():
|
31 |
+
with gr.Column():
|
32 |
+
with gr.Row():
|
33 |
+
gr.Markdown("You can upload either:")
|
34 |
+
with gr.Row():
|
35 |
+
gr.Markdown("- A single MP4 video file")
|
36 |
+
with gr.Row():
|
37 |
+
gr.Markdown("- A ZIP archive containing multiple videos/images and optional caption files")
|
38 |
+
with gr.Row():
|
39 |
+
gr.Markdown("- A WebDataset shard (.tar file)")
|
40 |
+
with gr.Row():
|
41 |
+
gr.Markdown("- A ZIP archive containing WebDataset shards (.tar files)")
|
42 |
+
with gr.Column():
|
43 |
+
with gr.Row():
|
44 |
+
self.components["files"] = gr.Files(
|
45 |
+
label="Upload Images, Videos, ZIP or WebDataset",
|
46 |
+
file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip", ".tar"],
|
47 |
+
type="filepath"
|
48 |
+
)
|
49 |
+
|
50 |
+
return tab
|
51 |
+
|
52 |
+
def connect_events(self) -> None:
|
53 |
+
"""Connect event handlers to UI components"""
|
54 |
+
# File upload event
|
55 |
+
self.components["files"].upload(
|
56 |
+
fn=lambda x: self.app.importer.process_uploaded_files(x),
|
57 |
+
inputs=[self.components["files"]],
|
58 |
+
outputs=[self.components["import_status"]] # This comes from parent tab
|
59 |
+
).success(
|
60 |
+
fn=self.app.tabs["import_tab"].update_titles_after_import,
|
61 |
+
inputs=[
|
62 |
+
self.components["enable_automatic_video_split"],
|
63 |
+
self.components["enable_automatic_content_captioning"],
|
64 |
+
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
65 |
+
],
|
66 |
+
outputs=[
|
67 |
+
self.app.tabs_component, # Main tabs component
|
68 |
+
self.app.tabs["split_tab"].components["video_list"],
|
69 |
+
self.app.tabs["split_tab"].components["detect_status"],
|
70 |
+
self.app.tabs["split_tab"].components["split_title"],
|
71 |
+
self.app.tabs["caption_tab"].components["caption_title"],
|
72 |
+
self.app.tabs["train_tab"].components["train_title"]
|
73 |
+
]
|
74 |
+
)
|
vms/tabs/import_tab/youtube_tab.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
YouTube tab for Video Model Studio UI.
|
3 |
+
Handles downloading videos from YouTube URLs.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import logging
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Dict, Any, Optional
|
10 |
+
|
11 |
+
from ..base_tab import BaseTab
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
class YouTubeTab(BaseTab):
|
16 |
+
"""YouTube tab for downloading videos from YouTube"""
|
17 |
+
|
18 |
+
def __init__(self, app_state):
|
19 |
+
super().__init__(app_state)
|
20 |
+
self.id = "youtube_tab"
|
21 |
+
self.title = "Download from YouTube"
|
22 |
+
|
23 |
+
def create(self, parent=None) -> gr.Tab:
|
24 |
+
"""Create the YouTube tab UI components"""
|
25 |
+
with gr.Tab(self.title, id=self.id) as tab:
|
26 |
+
with gr.Column():
|
27 |
+
with gr.Row():
|
28 |
+
gr.Markdown("## Import a YouTube video")
|
29 |
+
|
30 |
+
with gr.Row():
|
31 |
+
with gr.Column():
|
32 |
+
with gr.Row():
|
33 |
+
gr.Markdown("You can use a YouTube video as reference, by pasting its URL here:")
|
34 |
+
with gr.Row():
|
35 |
+
gr.Markdown("Please be aware of the [know limitations](https://stackoverflow.com/questions/78160027/how-to-solve-http-error-400-bad-request-in-pytube) and [issues](https://stackoverflow.com/questions/79226520/pytube-throws-http-error-403-forbidden-since-a-few-days)")
|
36 |
+
|
37 |
+
with gr.Column():
|
38 |
+
self.components["youtube_url"] = gr.Textbox(
|
39 |
+
label="Import YouTube Video",
|
40 |
+
placeholder="https://www.youtube.com/watch?v=..."
|
41 |
+
)
|
42 |
+
|
43 |
+
with gr.Row():
|
44 |
+
self.components["youtube_download_btn"] = gr.Button("Download YouTube Video", variant="primary")
|
45 |
+
|
46 |
+
return tab
|
47 |
+
|
48 |
+
def connect_events(self) -> None:
|
49 |
+
"""Connect event handlers to UI components"""
|
50 |
+
# YouTube download event
|
51 |
+
self.components["youtube_download_btn"].click(
|
52 |
+
fn=self.app.importer.download_youtube_video,
|
53 |
+
inputs=[self.components["youtube_url"]],
|
54 |
+
outputs=[self.components["import_status"]] # This comes from parent tab
|
55 |
+
).success(
|
56 |
+
fn=self.app.tabs["import_tab"].on_import_success,
|
57 |
+
inputs=[
|
58 |
+
self.components["enable_automatic_video_split"],
|
59 |
+
self.components["enable_automatic_content_captioning"],
|
60 |
+
self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
|
61 |
+
],
|
62 |
+
outputs=[
|
63 |
+
self.app.tabs_component,
|
64 |
+
self.app.tabs["split_tab"].components["video_list"],
|
65 |
+
self.app.tabs["split_tab"].components["detect_status"]
|
66 |
+
]
|
67 |
+
)
|
vms/tabs/manage_tab.py
CHANGED
@@ -23,7 +23,7 @@ class ManageTab(BaseTab):
|
|
23 |
def __init__(self, app_state):
|
24 |
super().__init__(app_state)
|
25 |
self.id = "manage_tab"
|
26 |
-
self.title = "
|
27 |
|
28 |
def create(self, parent=None) -> gr.TabItem:
|
29 |
"""Create the Manage tab UI components"""
|
|
|
23 |
def __init__(self, app_state):
|
24 |
super().__init__(app_state)
|
25 |
self.id = "manage_tab"
|
26 |
+
self.title = "6️⃣ Manage"
|
27 |
|
28 |
def create(self, parent=None) -> gr.TabItem:
|
29 |
"""Create the Manage tab UI components"""
|
vms/tabs/monitor_tab.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
System monitoring tab for Video Model Studio UI.
|
3 |
+
Displays system metrics like CPU, memory usage, and temperatures.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import time
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
import os
|
11 |
+
import psutil
|
12 |
+
from typing import Dict, Any, List, Optional, Tuple
|
13 |
+
from datetime import datetime, timedelta
|
14 |
+
|
15 |
+
from .base_tab import BaseTab
|
16 |
+
from ..config import STORAGE_PATH
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
def get_folder_size(path):
|
21 |
+
"""Calculate the total size of a folder in bytes"""
|
22 |
+
total_size = 0
|
23 |
+
for dirpath, dirnames, filenames in os.walk(path):
|
24 |
+
for filename in filenames:
|
25 |
+
file_path = os.path.join(dirpath, filename)
|
26 |
+
if not os.path.islink(file_path): # Skip symlinks
|
27 |
+
total_size += os.path.getsize(file_path)
|
28 |
+
return total_size
|
29 |
+
|
30 |
+
def human_readable_size(size_bytes):
|
31 |
+
"""Convert a size in bytes to a human-readable string"""
|
32 |
+
if size_bytes == 0:
|
33 |
+
return "0 B"
|
34 |
+
size_names = ("B", "KB", "MB", "GB", "TB", "PB")
|
35 |
+
i = 0
|
36 |
+
while size_bytes >= 1024 and i < len(size_names) - 1:
|
37 |
+
size_bytes /= 1024
|
38 |
+
i += 1
|
39 |
+
return f"{size_bytes:.2f} {size_names[i]}"
|
40 |
+
|
41 |
+
class MonitorTab(BaseTab):
|
42 |
+
"""Monitor tab for system resource monitoring"""
|
43 |
+
|
44 |
+
def __init__(self, app_state):
|
45 |
+
super().__init__(app_state)
|
46 |
+
self.id = "monitor_tab"
|
47 |
+
self.title = "4️⃣ Monitor"
|
48 |
+
self.refresh_interval = 2 # Changed from 5 to 2 seconds
|
49 |
+
|
50 |
+
def create(self, parent=None) -> gr.TabItem:
|
51 |
+
"""Create the Monitor tab UI components"""
|
52 |
+
with gr.TabItem(self.title, id=self.id) as tab:
|
53 |
+
with gr.Row():
|
54 |
+
gr.Markdown("## System Monitoring")
|
55 |
+
|
56 |
+
# Current metrics
|
57 |
+
with gr.Row():
|
58 |
+
with gr.Column(scale=1):
|
59 |
+
self.components["current_metrics"] = gr.Markdown("Loading current metrics...")
|
60 |
+
|
61 |
+
# CPU and Memory charts in tabs
|
62 |
+
with gr.Tabs() as metrics_tabs:
|
63 |
+
with gr.Tab(label="CPU Usage") as cpu_tab:
|
64 |
+
self.components["cpu_plot"] = gr.Plot()
|
65 |
+
|
66 |
+
with gr.Tab(label="Memory Usage") as memory_tab:
|
67 |
+
self.components["memory_plot"] = gr.Plot()
|
68 |
+
|
69 |
+
with gr.Tab(label="Per-Core CPU") as per_core_tab:
|
70 |
+
self.components["per_core_plot"] = gr.Plot()
|
71 |
+
|
72 |
+
# System information summary in columns
|
73 |
+
with gr.Row():
|
74 |
+
with gr.Column(scale=1):
|
75 |
+
gr.Markdown("### System Information")
|
76 |
+
self.components["system_info"] = gr.Markdown("Loading system information...")
|
77 |
+
|
78 |
+
with gr.Column(scale=1):
|
79 |
+
gr.Markdown("### CPU Information")
|
80 |
+
self.components["cpu_info"] = gr.Markdown("Loading CPU information...")
|
81 |
+
|
82 |
+
with gr.Row():
|
83 |
+
with gr.Column(scale=1):
|
84 |
+
gr.Markdown("### Memory Information")
|
85 |
+
self.components["memory_info"] = gr.Markdown("Loading memory information...")
|
86 |
+
|
87 |
+
with gr.Column(scale=1):
|
88 |
+
gr.Markdown("### Storage Information")
|
89 |
+
self.components["storage_info"] = gr.Markdown("Loading storage information...")
|
90 |
+
|
91 |
+
# Toggle for enabling/disabling auto-refresh
|
92 |
+
with gr.Row():
|
93 |
+
self.components["auto_refresh"] = gr.Checkbox(
|
94 |
+
label=f"Auto refresh (every {self.refresh_interval} seconds)",
|
95 |
+
value=True,
|
96 |
+
info="Automatically refresh system metrics"
|
97 |
+
)
|
98 |
+
self.components["refresh_btn"] = gr.Button("Refresh Now")
|
99 |
+
|
100 |
+
# Timer for auto-refresh
|
101 |
+
self.components["refresh_timer"] = gr.Timer(
|
102 |
+
value=self.refresh_interval
|
103 |
+
)
|
104 |
+
|
105 |
+
return tab
|
106 |
+
|
107 |
+
def connect_events(self) -> None:
|
108 |
+
"""Connect event handlers to UI components"""
|
109 |
+
# Manual refresh button
|
110 |
+
self.components["refresh_btn"].click(
|
111 |
+
fn=self.refresh_all,
|
112 |
+
outputs=[
|
113 |
+
self.components["system_info"],
|
114 |
+
self.components["cpu_info"],
|
115 |
+
self.components["memory_info"],
|
116 |
+
self.components["storage_info"],
|
117 |
+
self.components["current_metrics"],
|
118 |
+
self.components["cpu_plot"],
|
119 |
+
self.components["memory_plot"],
|
120 |
+
self.components["per_core_plot"]
|
121 |
+
]
|
122 |
+
)
|
123 |
+
|
124 |
+
# Auto-refresh timer
|
125 |
+
self.components["refresh_timer"].tick(
|
126 |
+
fn=self.conditional_refresh,
|
127 |
+
inputs=[self.components["auto_refresh"]],
|
128 |
+
outputs=[
|
129 |
+
self.components["system_info"],
|
130 |
+
self.components["cpu_info"],
|
131 |
+
self.components["memory_info"],
|
132 |
+
self.components["storage_info"],
|
133 |
+
self.components["current_metrics"],
|
134 |
+
self.components["cpu_plot"],
|
135 |
+
self.components["memory_plot"],
|
136 |
+
self.components["per_core_plot"]
|
137 |
+
]
|
138 |
+
)
|
139 |
+
|
140 |
+
def on_enter(self):
|
141 |
+
"""Called when the tab is selected"""
|
142 |
+
# Start monitoring service if not already running
|
143 |
+
if not self.app.monitor.is_running:
|
144 |
+
self.app.monitor.start_monitoring()
|
145 |
+
|
146 |
+
# Trigger initial refresh
|
147 |
+
return self.refresh_all()
|
148 |
+
|
149 |
+
def conditional_refresh(self, auto_refresh: bool) -> Tuple:
|
150 |
+
"""Only refresh if auto-refresh is enabled
|
151 |
+
|
152 |
+
Args:
|
153 |
+
auto_refresh: Whether auto-refresh is enabled
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
Updated components or unchanged components
|
157 |
+
"""
|
158 |
+
if auto_refresh:
|
159 |
+
return self.refresh_all()
|
160 |
+
|
161 |
+
# Return current values unchanged if auto-refresh is disabled
|
162 |
+
return (
|
163 |
+
self.components["system_info"].value,
|
164 |
+
self.components["cpu_info"].value,
|
165 |
+
self.components["memory_info"].value,
|
166 |
+
self.components["storage_info"].value,
|
167 |
+
self.components["current_metrics"].value,
|
168 |
+
self.components["cpu_plot"].value,
|
169 |
+
self.components["memory_plot"].value,
|
170 |
+
self.components["per_core_plot"].value
|
171 |
+
)
|
172 |
+
|
173 |
+
def refresh_all(self) -> Tuple:
|
174 |
+
"""Refresh all monitoring components
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
Updated values for all components
|
178 |
+
"""
|
179 |
+
try:
|
180 |
+
# Get system info
|
181 |
+
system_info = self.app.monitor.get_system_info()
|
182 |
+
|
183 |
+
# Split system info into separate components
|
184 |
+
system_info_html = self.format_system_info(system_info)
|
185 |
+
cpu_info_html = self.format_cpu_info(system_info)
|
186 |
+
memory_info_html = self.format_memory_info(system_info)
|
187 |
+
storage_info_html = self.format_storage_info()
|
188 |
+
|
189 |
+
# Get current metrics
|
190 |
+
current_metrics = self.app.monitor.get_current_metrics()
|
191 |
+
metrics_html = self.format_current_metrics(current_metrics)
|
192 |
+
|
193 |
+
# Generate plots
|
194 |
+
cpu_plot = self.app.monitor.generate_cpu_plot()
|
195 |
+
memory_plot = self.app.monitor.generate_memory_plot()
|
196 |
+
per_core_plot = self.app.monitor.generate_per_core_plot()
|
197 |
+
|
198 |
+
return (
|
199 |
+
system_info_html,
|
200 |
+
cpu_info_html,
|
201 |
+
memory_info_html,
|
202 |
+
storage_info_html,
|
203 |
+
metrics_html,
|
204 |
+
cpu_plot,
|
205 |
+
memory_plot,
|
206 |
+
per_core_plot
|
207 |
+
)
|
208 |
+
|
209 |
+
except Exception as e:
|
210 |
+
logger.error(f"Error refreshing monitoring data: {str(e)}", exc_info=True)
|
211 |
+
error_msg = f"Error retrieving data: {str(e)}"
|
212 |
+
return (
|
213 |
+
error_msg,
|
214 |
+
error_msg,
|
215 |
+
error_msg,
|
216 |
+
error_msg,
|
217 |
+
error_msg,
|
218 |
+
None, None, None
|
219 |
+
)
|
220 |
+
|
221 |
+
def format_system_info(self, system_info: Dict[str, Any]) -> str:
|
222 |
+
"""Format system information as HTML
|
223 |
+
|
224 |
+
Args:
|
225 |
+
system_info: System information dictionary
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
Formatted HTML string
|
229 |
+
"""
|
230 |
+
sys = system_info['system']
|
231 |
+
uptime_str = self.format_uptime(sys['uptime'])
|
232 |
+
|
233 |
+
html = f"""
|
234 |
+
**System:** {sys['system']} ({sys['platform']})
|
235 |
+
**Hostname:** {sys['hostname']}
|
236 |
+
**Uptime:** {uptime_str}
|
237 |
+
**Python Version:** {sys['python_version']}
|
238 |
+
"""
|
239 |
+
return html
|
240 |
+
|
241 |
+
def format_cpu_info(self, system_info: Dict[str, Any]) -> str:
|
242 |
+
"""Format CPU information as HTML
|
243 |
+
|
244 |
+
Args:
|
245 |
+
system_info: System information dictionary
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Formatted HTML string
|
249 |
+
"""
|
250 |
+
cpu = system_info['cpu']
|
251 |
+
sys = system_info['system']
|
252 |
+
|
253 |
+
# Format CPU frequency
|
254 |
+
cpu_freq = "N/A"
|
255 |
+
if cpu['current_frequency']:
|
256 |
+
cpu_freq = f"{cpu['current_frequency'] / 1000:.2f} GHz"
|
257 |
+
|
258 |
+
html = f"""
|
259 |
+
**Processor:** {sys['processor'] or cpu['architecture']}
|
260 |
+
**Physical Cores:** {cpu['cores_physical']}
|
261 |
+
**Logical Cores:** {cpu['cores_logical']}
|
262 |
+
**Current Frequency:** {cpu_freq}
|
263 |
+
"""
|
264 |
+
return html
|
265 |
+
|
266 |
+
def format_memory_info(self, system_info: Dict[str, Any]) -> str:
|
267 |
+
"""Format memory information as HTML
|
268 |
+
|
269 |
+
Args:
|
270 |
+
system_info: System information dictionary
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
Formatted HTML string
|
274 |
+
"""
|
275 |
+
memory = system_info['memory']
|
276 |
+
|
277 |
+
html = f"""
|
278 |
+
**Total Memory:** {memory['total']:.2f} GB
|
279 |
+
**Available Memory:** {memory['available']:.2f} GB
|
280 |
+
**Used Memory:** {memory['used']:.2f} GB
|
281 |
+
**Usage:** {memory['percent']}%
|
282 |
+
"""
|
283 |
+
return html
|
284 |
+
|
285 |
+
def format_storage_info(self) -> str:
|
286 |
+
"""Format storage information as HTML, focused on STORAGE_PATH
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
Formatted HTML string
|
290 |
+
"""
|
291 |
+
try:
|
292 |
+
# Get total size of STORAGE_PATH
|
293 |
+
total_size = get_folder_size(STORAGE_PATH)
|
294 |
+
total_size_readable = human_readable_size(total_size)
|
295 |
+
|
296 |
+
html = f"**Total Storage Used:** {total_size_readable}\n\n"
|
297 |
+
|
298 |
+
# Get size of each subfolder
|
299 |
+
html += "**Subfolder Sizes:**\n\n"
|
300 |
+
|
301 |
+
for subfolder in sorted(STORAGE_PATH.iterdir()):
|
302 |
+
if subfolder.is_dir():
|
303 |
+
folder_size = get_folder_size(subfolder)
|
304 |
+
folder_size_readable = human_readable_size(folder_size)
|
305 |
+
percentage = (folder_size / total_size * 100) if total_size > 0 else 0
|
306 |
+
|
307 |
+
folder_name = subfolder.name
|
308 |
+
html += f"* **{folder_name}**: {folder_size_readable} ({percentage:.1f}%)\n"
|
309 |
+
|
310 |
+
return html
|
311 |
+
|
312 |
+
except Exception as e:
|
313 |
+
logger.error(f"Error getting folder sizes: {str(e)}", exc_info=True)
|
314 |
+
return f"Error getting folder sizes: {str(e)}"
|
315 |
+
|
316 |
+
def format_current_metrics(self, metrics: Dict[str, Any]) -> str:
|
317 |
+
"""Format current metrics as HTML
|
318 |
+
|
319 |
+
Args:
|
320 |
+
metrics: Current metrics dictionary
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
Formatted HTML string
|
324 |
+
"""
|
325 |
+
timestamp = metrics['timestamp'].strftime('%Y-%m-%d %H:%M:%S')
|
326 |
+
|
327 |
+
# Style for CPU usage
|
328 |
+
cpu_style = "color: green;"
|
329 |
+
if metrics['cpu_percent'] > 90:
|
330 |
+
cpu_style = "color: red; font-weight: bold;"
|
331 |
+
elif metrics['cpu_percent'] > 70:
|
332 |
+
cpu_style = "color: orange;"
|
333 |
+
|
334 |
+
# Style for memory usage
|
335 |
+
mem_style = "color: green;"
|
336 |
+
if metrics['memory_percent'] > 90:
|
337 |
+
mem_style = "color: red; font-weight: bold;"
|
338 |
+
elif metrics['memory_percent'] > 70:
|
339 |
+
mem_style = "color: orange;"
|
340 |
+
|
341 |
+
# Temperature info
|
342 |
+
temp_html = ""
|
343 |
+
if metrics['cpu_temp'] is not None:
|
344 |
+
temp_style = "color: green;"
|
345 |
+
if metrics['cpu_temp'] > 80:
|
346 |
+
temp_style = "color: red; font-weight: bold;"
|
347 |
+
elif metrics['cpu_temp'] > 70:
|
348 |
+
temp_style = "color: orange;"
|
349 |
+
|
350 |
+
temp_html = f"""
|
351 |
+
**CPU Temperature:** <span style="{temp_style}">{metrics['cpu_temp']:.1f}°C</span>
|
352 |
+
"""
|
353 |
+
|
354 |
+
html = f"""
|
355 |
+
**CPU Usage:** <span style="{cpu_style}">{metrics['cpu_percent']:.1f}%</span>
|
356 |
+
**Memory Usage:** <span style="{mem_style}">{metrics['memory_percent']:.1f}% ({metrics['memory_used']:.2f}/{metrics['memory_available']:.2f} GB)</span>
|
357 |
+
{temp_html}
|
358 |
+
"""
|
359 |
+
|
360 |
+
# Add per-CPU core info
|
361 |
+
html += "\n"
|
362 |
+
|
363 |
+
per_cpu = metrics['per_cpu_percent']
|
364 |
+
cols = 4 # 4 cores per row
|
365 |
+
|
366 |
+
# Create a grid layout for cores
|
367 |
+
for i in range(0, len(per_cpu), cols):
|
368 |
+
row_cores = per_cpu[i:i+cols]
|
369 |
+
row_html = ""
|
370 |
+
|
371 |
+
for j, usage in enumerate(row_cores):
|
372 |
+
core_id = i + j
|
373 |
+
core_style = "color: green;"
|
374 |
+
if usage > 90:
|
375 |
+
core_style = "color: red; font-weight: bold;"
|
376 |
+
elif usage > 70:
|
377 |
+
core_style = "color: orange;"
|
378 |
+
|
379 |
+
row_html += f"**Core {core_id}:** <span style='{core_style}'>{usage:.1f}%</span> "
|
380 |
+
|
381 |
+
html += row_html + "\n"
|
382 |
+
|
383 |
+
return html
|
384 |
+
|
385 |
+
def format_uptime(self, seconds: float) -> str:
|
386 |
+
"""Format uptime in seconds to a human-readable string
|
387 |
+
|
388 |
+
Args:
|
389 |
+
seconds: Uptime in seconds
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
Formatted uptime string
|
393 |
+
"""
|
394 |
+
days = int(seconds // 86400)
|
395 |
+
seconds %= 86400
|
396 |
+
hours = int(seconds // 3600)
|
397 |
+
seconds %= 3600
|
398 |
+
minutes = int(seconds // 60)
|
399 |
+
|
400 |
+
parts = []
|
401 |
+
if days > 0:
|
402 |
+
parts.append(f"{days} day{'s' if days != 1 else ''}")
|
403 |
+
if hours > 0 or days > 0:
|
404 |
+
parts.append(f"{hours} hour{'s' if hours != 1 else ''}")
|
405 |
+
parts.append(f"{minutes} minute{'s' if minutes != 1 else ''}")
|
406 |
+
|
407 |
+
return ", ".join(parts)
|
vms/ui/video_trainer_ui.py
CHANGED
@@ -5,7 +5,7 @@ import logging
|
|
5 |
import asyncio
|
6 |
from typing import Any, Optional, Dict, List, Union, Tuple
|
7 |
|
8 |
-
from ..services import TrainingService, CaptioningService, SplittingService, ImportService
|
9 |
from ..config import (
|
10 |
STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
|
11 |
TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
|
@@ -28,7 +28,7 @@ from ..utils import (
|
|
28 |
format_media_title,
|
29 |
TrainingLogParser
|
30 |
)
|
31 |
-
from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
|
32 |
|
33 |
logger = logging.getLogger(__name__)
|
34 |
logger.setLevel(logging.INFO)
|
@@ -44,7 +44,11 @@ class VideoTrainerUI:
|
|
44 |
self.splitter = SplittingService()
|
45 |
self.importer = ImportService()
|
46 |
self.captioner = CaptioningService()
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
# Recovery status from any interrupted training
|
49 |
recovery_result = self.trainer.recover_interrupted_training()
|
50 |
# Add null check for recovery_result
|
@@ -81,6 +85,7 @@ class VideoTrainerUI:
|
|
81 |
self.tabs["split_tab"] = SplitTab(self)
|
82 |
self.tabs["caption_tab"] = CaptionTab(self)
|
83 |
self.tabs["train_tab"] = TrainTab(self)
|
|
|
84 |
self.tabs["manage_tab"] = ManageTab(self)
|
85 |
|
86 |
# Create tab UI components
|
|
|
5 |
import asyncio
|
6 |
from typing import Any, Optional, Dict, List, Union, Tuple
|
7 |
|
8 |
+
from ..services import TrainingService, CaptioningService, SplittingService, ImportService, MonitoringService
|
9 |
from ..config import (
|
10 |
STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
|
11 |
TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
|
|
|
28 |
format_media_title,
|
29 |
TrainingLogParser
|
30 |
)
|
31 |
+
from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, MonitorTab, ManageTab
|
32 |
|
33 |
logger = logging.getLogger(__name__)
|
34 |
logger.setLevel(logging.INFO)
|
|
|
44 |
self.splitter = SplittingService()
|
45 |
self.importer = ImportService()
|
46 |
self.captioner = CaptioningService()
|
47 |
+
self.monitor = MonitoringService()
|
48 |
+
|
49 |
+
# Start the monitoring service on app creation
|
50 |
+
self.monitor.start_monitoring()
|
51 |
+
|
52 |
# Recovery status from any interrupted training
|
53 |
recovery_result = self.trainer.recover_interrupted_training()
|
54 |
# Add null check for recovery_result
|
|
|
85 |
self.tabs["split_tab"] = SplitTab(self)
|
86 |
self.tabs["caption_tab"] = CaptionTab(self)
|
87 |
self.tabs["train_tab"] = TrainTab(self)
|
88 |
+
self.tabs["monitor_tab"] = MonitorTab(self)
|
89 |
self.tabs["manage_tab"] = ManageTab(self)
|
90 |
|
91 |
# Create tab UI components
|