jbilcke-hf (HF Staff) committed
Commit b613c3c · 1 Parent(s): ecd5028

work on basic monitor (no gpu for now)

README.md CHANGED
@@ -120,6 +120,33 @@ As this is not automatic, then click on "Restart" in the space dev mode UI widge
120
 
121
I haven't tested it, but you can try the provided Dockerfile
122

123
  ### Full installation in local
124
 
125
The full installation requires:
@@ -127,7 +154,7 @@ the full installation requires:
127
  - CUDA 12
128
  - Python 3.10
129
 
130
- This is because of flash attention, which is defined in the `requirements.txt` using an URL to download a prebuilt wheel (python bindings for a native library)
131
 
132
  ```bash
133
  ./setup.sh
@@ -153,7 +180,7 @@ Here is how to do solution 3:
153
Note: please make sure you properly define the environment variables `STORAGE_PATH` (e.g. `/data/`) and `HF_HOME` (e.g. `/data/huggingface/`)
154
 
155
  ```bash
156
- python app.py
157
  ```
158
 
159
  ### Running locally
 
120
 
121
I haven't tested it, but you can try the provided Dockerfile
122
 
123
+ ### Prerequisites
124
+
125
+ About Python:
126
+
127
+ I haven't tested Python 3.11 or 3.12, but I noticed some incompatibilities with Python 3.13: some dependencies fail to install.
128
+
129
+ So I recommend installing [pyenv](https://github.com/pyenv/pyenv) to switch between Python versions.
130
+
131
+ If you are on macOS, you might already have several versions of Python installed; you can list them by typing:
132
+
133
+ ```bash
134
+ % python3.10 --version
135
+ Python 3.10.16
136
+ % python3.11 --version
137
+ Python 3.11.11
138
+ % python3.12 --version
139
+ Python 3.12.9
140
+ % python3.13 --version
141
+ Python 3.13.2
142
+ ```
143
+
144
+ Once pyenv is installed, you can type:
145
+
146
+ ```bash
147
+ pyenv install 3.10.16
148
+ ```
149
+
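Once a 3.10.x interpreter is installed, you will probably want the project to pick it up automatically. Here is a hedged sketch of one way to do that (not part of this commit; it assumes pyenv's shims are on your `PATH`):

```bash
# pin Python 3.10.16 for this project directory (pyenv writes a .python-version file)
pyenv local 3.10.16

# check which interpreter is now resolved
python --version   # should print: Python 3.10.16

# then run the regular setup script
./setup.sh
```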
150
  ### Full installation in local
151
 
152
The full installation requires:
 
154
  - CUDA 12
155
  - Python 3.10
156
 
157
+ This is because of flash attention, which is defined in `requirements.txt` using a URL to download a prebuilt wheel that expects this exact configuration (Python bindings for a native library)
158
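If the prebuilt wheel does not match your machine, a hedged alternative (not part of this commit) is to install everything else from `requirements_without_flash_attention.txt` and handle flash attention separately:

```bash
# install all dependencies except the pinned flash-attention wheel
python3.10 -m pip install -r requirements_without_flash_attention.txt

# optionally build/install flash attention yourself (needs a CUDA toolchain and can take a while)
python3.10 -m pip install flash-attn --no-build-isolation
```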
 
159
  ```bash
160
  ./setup.sh
 
180
Note: please make sure you properly define the environment variables `STORAGE_PATH` (e.g. `/data/`) and `HF_HOME` (e.g. `/data/huggingface/`)
181
 
182
  ```bash
183
+ python3.10 app.py
184
  ```
185
 
186
  ### Running locally
app.py CHANGED
@@ -14,6 +14,7 @@ from vms.config import (
14
  OUTPUT_PATH, ASK_USER_TO_DUPLICATE_SPACE,
15
  HF_API_TOKEN
16
  )
 
17
  from vms.ui.video_trainer_ui import VideoTrainerUI
18
 
19
  # Configure logging
@@ -37,7 +38,9 @@ To avoid overpaying for your space, you can configure the auto-sleep settings to
37
 
38
  # Create the main application UI
39
  ui = VideoTrainerUI()
40
- return ui.create_ui()
 
 
41
 
42
  def main():
43
  """Main entry point for the application"""
 
14
  OUTPUT_PATH, ASK_USER_TO_DUPLICATE_SPACE,
15
  HF_API_TOKEN
16
  )
17
+
18
  from vms.ui.video_trainer_ui import VideoTrainerUI
19
 
20
  # Configure logging
 
38
 
39
  # Create the main application UI
40
  ui = VideoTrainerUI()
41
+ app = ui.create_ui()
42
+
43
+ return app
44
 
45
  def main():
46
  """Main entry point for the application"""
requirements.txt CHANGED
@@ -2,6 +2,7 @@ numpy>=1.26.4
2
 
3
  # to quote a-r-r-o-w/finetrainers:
4
  # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
 
5
  torch==2.5.1
6
  torchvision==0.20.1
7
  torchao==0.6.1
@@ -41,4 +42,7 @@ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
41
 
42
  # for our frontend
43
  gradio==5.20.1
44
- gradio_toggle
 
 
 
 
2
 
3
  # to quote a-r-r-o-w/finetrainers:
4
  # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
5
+ # on some systems (Python 3.13+) these do not work:
6
  torch==2.5.1
7
  torchvision==0.20.1
8
  torchao==0.6.1
 
42
 
43
  # for our frontend
44
  gradio==5.20.1
45
+ gradio_toggle
46
+
47
+ # used for the monitor
48
+ matplotlib
requirements_without_flash_attention.txt CHANGED
@@ -2,11 +2,12 @@ numpy>=1.26.4
2
 
3
  # to quote a-r-r-o-w/finetrainers:
4
  # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
 
 
5
  torch==2.5.1
6
  torchvision==0.20.1
7
  torchao==0.6.1
8
 
9
-
10
  huggingface_hub
11
  hf_transfer>=0.1.8
12
  diffusers @ git+https://github.com/huggingface/diffusers.git@main
@@ -40,4 +41,7 @@ git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
40
 
41
  # for our frontend
42
  gradio==5.20.1
43
- gradio_toggle
 
 
 
 
2
 
3
  # to quote a-r-r-o-w/finetrainers:
4
  # It is recommended to use Pytorch 2.5.1 or above for training. Previous versions can lead to completely black videos, OOM errors, or other issues and are not tested.
5
+
6
+ # on some systems (Python 3.13+) these do not work:
7
  torch==2.5.1
8
  torchvision==0.20.1
9
  torchao==0.6.1
10
 
 
11
  huggingface_hub
12
  hf_transfer>=0.1.8
13
  diffusers @ git+https://github.com/huggingface/diffusers.git@main
 
41
 
42
  # for our frontend
43
  gradio==5.20.1
44
+ gradio_toggle
45
+
46
+ # used for the monitor
47
+ matplotlib
run.sh CHANGED
@@ -2,4 +2,12 @@
2
 
3
  source .venv/bin/activate
4
 
5
- USE_MOCK_CAPTIONING_MODEL=True python app.py
2
 
3
  source .venv/bin/activate
4
 
5
+ echo "if run.sh fails because python3.10 is not found, edit run.sh to use another Python version"
6
+
7
+ # if you are on a Mac, you can try replacing "python3.10" with:
8
+ # python3.10
9
+ # python3.11 (not tested)
10
+ # python3.12 (not tested)
11
+ # python3.13 (tested, fails to install)
12
+
13
+ USE_MOCK_CAPTIONING_MODEL=True python3.10 app.py
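Hard-coding `python3.10` keeps the script simple; if you prefer it to fall back to whichever supported interpreter is available, a hedged sketch (not part of this commit) could look like this:

```bash
# try known interpreters in order of preference; 3.13 is skipped because installs fail there
for candidate in python3.10 python3.11 python3.12; do
  if command -v "$candidate" >/dev/null 2>&1; then
    PYTHON_BIN="$candidate"
    break
  fi
done

USE_MOCK_CAPTIONING_MODEL=True "${PYTHON_BIN:?no supported Python interpreter found}" app.py
```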
setup_no_captions.sh CHANGED
@@ -1,10 +1,18 @@
1
  #!/usr/bin/env bash
2
 
3
- python -m venv .venv
4
 
5
  source .venv/bin/activate
6
 
7
- python -m pip install -r requirements_without_flash_attention.txt
8
 
9
  # if you require flash attention, please install it manually for your operating system
10
 
 
1
  #!/usr/bin/env bash
2
 
3
+ echo "if the install fails because python3.10 is not found, edit setup_no_captions.sh to use another Python version"
4
+
5
+ # if you are on a Mac, you can try replacing "python3.10" with:
6
+ # python3.10
7
+ # python3.11 (not tested)
8
+ # python3.12 (not tested)
9
+ # python3.13 (tested, fails to install)
10
+
11
+ python3.10 -m venv .venv
12
 
13
  source .venv/bin/activate
14
 
15
+ python3.10 -m pip install -r requirements_without_flash_attention.txt
16
 
17
  # if you require flash attention, please install it manually for your operating system
18
 
vms/services/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .captioner import CaptioningProgress, CaptioningService
2
  from .importer import ImportService
 
3
  from .splitter import SplittingService
4
  from .trainer import TrainingService
5
 
@@ -7,6 +8,7 @@ __all__ = [
7
  'CaptioningProgress',
8
  'CaptioningService',
9
  'ImportService',
 
10
  'SplittingService',
11
  'TrainingService',
12
  ]
 
1
  from .captioner import CaptioningProgress, CaptioningService
2
  from .importer import ImportService
3
+ from .monitoring import MonitoringService
4
  from .splitter import SplittingService
5
  from .trainer import TrainingService
6
 
 
8
  'CaptioningProgress',
9
  'CaptioningService',
10
  'ImportService',
11
+ 'MonitoringService',
12
  'SplittingService',
13
  'TrainingService',
14
  ]
vms/services/importer/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Import module for Video Model Studio.
3
+ Handles file uploads, YouTube downloads, and Hugging Face Hub dataset integration.
4
+ """
5
+
6
+ from .import_service import ImportService
7
+ from .file_upload import FileUploadHandler
8
+ from .youtube import YouTubeDownloader
9
+ from .hub_dataset import HubDatasetBrowser
10
+
11
+ __all__ = ['ImportService', 'FileUploadHandler', 'YouTubeDownloader', 'HubDatasetBrowser']
vms/services/{importer.py → importer/file_upload.py} RENAMED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import os
2
  import shutil
3
  import zipfile
@@ -5,16 +10,18 @@ import tarfile
5
  import tempfile
6
  import gradio as gr
7
  from pathlib import Path
8
- from typing import List, Dict, Optional, Tuple
9
- from pytubefix import YouTube
10
  import logging
 
11
 
12
- from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
13
- from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
14
 
15
  logger = logging.getLogger(__name__)
16
 
17
- class ImportService:
 
 
18
  def process_uploaded_files(self, file_paths: List[str]) -> str:
19
  """Process uploaded file (ZIP, TAR, MP4, or image)
20
 
@@ -24,11 +31,15 @@ class ImportService:
24
  Returns:
25
  Status message string
26
  """
 
 
 
 
27
  for file_path in file_paths:
28
  file_path = Path(file_path)
29
  try:
30
  original_name = file_path.name
31
- print("original_name = ", original_name)
32
 
33
  # Determine file type from name
34
  file_ext = file_path.suffix.lower()
@@ -42,9 +53,11 @@ class ImportService:
42
  elif is_image_file(file_path):
43
  return self.process_image_file(file_path, original_name)
44
  else:
 
45
  raise gr.Error(f"Unsupported file type: {file_ext}")
46
 
47
  except Exception as e:
 
48
  raise gr.Error(f"Error processing file: {str(e)}")
49
 
50
  def process_image_file(self, file_path: Path, original_name: str) -> str:
@@ -68,10 +81,13 @@ class ImportService:
68
  target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
69
  counter += 1
70
 
 
 
71
  # Convert to normalized format and remove black bars
72
  success = normalize_image(file_path, target_path)
73
 
74
  if not success:
 
75
  raise gr.Error(f"Failed to process image: {original_name}")
76
 
77
  # Handle caption
@@ -86,6 +102,7 @@ class ImportService:
86
  return f"Successfully stored image: {target_path.name}"
87
 
88
  except Exception as e:
 
89
  raise gr.Error(f"Error processing image file: {str(e)}")
90
 
91
  def process_zip_file(self, file_path: Path) -> str:
@@ -102,6 +119,8 @@ class ImportService:
102
  image_count = 0
103
  tar_count = 0
104
 
 
 
105
  # Create temporary directory
106
  with tempfile.TemporaryDirectory() as temp_dir:
107
  # Extract ZIP
@@ -121,6 +140,7 @@ class ImportService:
121
  try:
122
  # Check if it's a WebDataset tar file
123
  if file.lower().endswith('.tar'):
 
124
  # Process WebDataset shard
125
  vid_count, img_count = webdataset_handler.process_webdataset_shard(
126
  file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
@@ -136,6 +156,7 @@ class ImportService:
136
  target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}"
137
  counter += 1
138
  shutil.copy2(file_path, target_path)
 
139
  video_count += 1
140
 
141
  elif is_image_file(file_path):
@@ -146,6 +167,7 @@ class ImportService:
146
  target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
147
  counter += 1
148
  if normalize_image(file_path, target_path):
 
149
  image_count += 1
150
 
151
  # Copy associated caption file if it exists
@@ -153,13 +175,15 @@ class ImportService:
153
  if txt_path.exists() and not file.lower().endswith('.tar'):
154
  if is_video_file(file_path):
155
  shutil.copy2(txt_path, target_path.with_suffix('.txt'))
 
156
  elif is_image_file(file_path):
157
  caption = txt_path.read_text()
158
  caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
159
  target_path.with_suffix('.txt').write_text(caption)
 
160
 
161
  except Exception as e:
162
- logger.error(f"Error processing {file_path.name}: {str(e)}")
163
  continue
164
 
165
  # Generate status message
@@ -172,13 +196,16 @@ class ImportService:
172
  parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
173
 
174
  if not parts:
 
175
  return "No supported media files found in ZIP"
176
 
177
  status = f"Successfully stored {', '.join(parts)}"
 
178
  gr.Info(status)
179
  return status
180
 
181
  except Exception as e:
 
182
  raise gr.Error(f"Error processing ZIP: {str(e)}")
183
 
184
  def process_tar_file(self, file_path: Path) -> str:
@@ -191,6 +218,7 @@ class ImportService:
191
  Status message string
192
  """
193
  try:
 
194
  video_count, image_count = webdataset_handler.process_webdataset_shard(
195
  file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
196
  )
@@ -203,13 +231,16 @@ class ImportService:
203
  parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
204
 
205
  if not parts:
 
206
  return "No supported media files found in WebDataset"
207
 
208
  status = f"Successfully extracted {' and '.join(parts)} from WebDataset"
 
209
  gr.Info(status)
210
  return status
211
 
212
  except Exception as e:
 
213
  raise gr.Error(f"Error processing WebDataset tar file: {str(e)}")
214
 
215
  def process_mp4_file(self, file_path: Path, original_name: str) -> str:
@@ -233,60 +264,15 @@ class ImportService:
233
  target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4"
234
  counter += 1
235
 
 
 
236
  # Copy the file to the target location
237
  shutil.copy2(file_path, target_path)
238
 
 
239
  gr.Info(f"Successfully stored video: {target_path.name}")
240
  return f"Successfully stored video: {target_path.name}"
241
 
242
  except Exception as e:
243
- raise gr.Error(f"Error processing video file: {str(e)}")
244
-
245
- def download_youtube_video(self, url: str, progress=None) -> Dict:
246
- """Download a video from YouTube
247
-
248
- Args:
249
- url: YouTube video URL
250
- progress: Optional Gradio progress indicator
251
-
252
- Returns:
253
- Dict with status message and error (if any)
254
- """
255
- try:
256
- # Extract video ID and create YouTube object
257
- yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining:
258
- progress((1 - bytes_remaining / stream.filesize), desc="Downloading...")
259
- if progress else None)
260
-
261
- video_id = yt.video_id
262
- output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"
263
-
264
- # Download highest quality progressive MP4
265
- if progress:
266
- print("Getting video streams...")
267
- progress(0, desc="Getting video streams...")
268
- video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
269
-
270
- if not video:
271
- print("Could not find a compatible video format")
272
- gr.Error("Could not find a compatible video format")
273
- return "Could not find a compatible video format"
274
-
275
- # Download the video
276
- if progress:
277
- print("Starting YouTube video download...")
278
- progress(0, desc="Starting download...")
279
-
280
- video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")
281
-
282
- # Update UI
283
- if progress:
284
- print("YouTube video download complete!")
285
- gr.Info("YouTube video download complete!")
286
- progress(1, desc="Download complete!")
287
- return f"Successfully downloaded video: {yt.title}"
288
-
289
- except Exception as e:
290
- print(e)
291
- gr.Error(f"Error downloading video: {str(e)}")
292
- return f"Error downloading video: {str(e)}"
 
1
+ """
2
+ File upload handler for Video Model Studio.
3
+ Processes uploaded files including videos, images, ZIPs, and WebDataset archives.
4
+ """
5
+
6
  import os
7
  import shutil
8
  import zipfile
 
10
  import tempfile
11
  import gradio as gr
12
  from pathlib import Path
13
+ from typing import List, Dict, Optional, Tuple, Any, Union
 
14
  import logging
15
+ import traceback
16
 
17
+ from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
18
+ from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
19
 
20
  logger = logging.getLogger(__name__)
21
 
22
+ class FileUploadHandler:
23
+ """Handles processing of uploaded files"""
24
+
25
  def process_uploaded_files(self, file_paths: List[str]) -> str:
26
  """Process uploaded file (ZIP, TAR, MP4, or image)
27
 
 
31
  Returns:
32
  Status message string
33
  """
34
+ if not file_paths or len(file_paths) == 0:
35
+ logger.warning("No files provided to process_uploaded_files")
36
+ return "No files provided"
37
+
38
  for file_path in file_paths:
39
  file_path = Path(file_path)
40
  try:
41
  original_name = file_path.name
42
+ logger.info(f"Processing uploaded file: {original_name}")
43
 
44
  # Determine file type from name
45
  file_ext = file_path.suffix.lower()
 
53
  elif is_image_file(file_path):
54
  return self.process_image_file(file_path, original_name)
55
  else:
56
+ logger.error(f"Unsupported file type: {file_ext}")
57
  raise gr.Error(f"Unsupported file type: {file_ext}")
58
 
59
  except Exception as e:
60
+ logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
61
  raise gr.Error(f"Error processing file: {str(e)}")
62
 
63
  def process_image_file(self, file_path: Path, original_name: str) -> str:
 
81
  target_path = STAGING_PATH / f"{stem}___{counter}.{NORMALIZE_IMAGES_TO}"
82
  counter += 1
83
 
84
+ logger.info(f"Processing image file: {original_name} -> {target_path}")
85
+
86
  # Convert to normalized format and remove black bars
87
  success = normalize_image(file_path, target_path)
88
 
89
  if not success:
90
+ logger.error(f"Failed to process image: {original_name}")
91
  raise gr.Error(f"Failed to process image: {original_name}")
92
 
93
  # Handle caption
 
102
  return f"Successfully stored image: {target_path.name}"
103
 
104
  except Exception as e:
105
+ logger.error(f"Error processing image file: {str(e)}", exc_info=True)
106
  raise gr.Error(f"Error processing image file: {str(e)}")
107
 
108
  def process_zip_file(self, file_path: Path) -> str:
 
119
  image_count = 0
120
  tar_count = 0
121
 
122
+ logger.info(f"Processing ZIP file: {file_path}")
123
+
124
  # Create temporary directory
125
  with tempfile.TemporaryDirectory() as temp_dir:
126
  # Extract ZIP
 
140
  try:
141
  # Check if it's a WebDataset tar file
142
  if file.lower().endswith('.tar'):
143
+ logger.info(f"Processing WebDataset archive from ZIP: {file}")
144
  # Process WebDataset shard
145
  vid_count, img_count = webdataset_handler.process_webdataset_shard(
146
  file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
 
156
  target_path = VIDEOS_TO_SPLIT_PATH / f"{file_path.stem}___{counter}{file_path.suffix}"
157
  counter += 1
158
  shutil.copy2(file_path, target_path)
159
+ logger.info(f"Extracted video from ZIP: {file} -> {target_path.name}")
160
  video_count += 1
161
 
162
  elif is_image_file(file_path):
 
167
  target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
168
  counter += 1
169
  if normalize_image(file_path, target_path):
170
+ logger.info(f"Extracted image from ZIP: {file} -> {target_path.name}")
171
  image_count += 1
172
 
173
  # Copy associated caption file if it exists
 
175
  if txt_path.exists() and not file.lower().endswith('.tar'):
176
  if is_video_file(file_path):
177
  shutil.copy2(txt_path, target_path.with_suffix('.txt'))
178
+ logger.info(f"Copied caption file for {file}")
179
  elif is_image_file(file_path):
180
  caption = txt_path.read_text()
181
  caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
182
  target_path.with_suffix('.txt').write_text(caption)
183
+ logger.info(f"Processed caption for {file}")
184
 
185
  except Exception as e:
186
+ logger.error(f"Error processing {file_path.name} from ZIP: {str(e)}", exc_info=True)
187
  continue
188
 
189
  # Generate status message
 
196
  parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
197
 
198
  if not parts:
199
+ logger.warning("No supported media files found in ZIP")
200
  return "No supported media files found in ZIP"
201
 
202
  status = f"Successfully stored {', '.join(parts)}"
203
+ logger.info(status)
204
  gr.Info(status)
205
  return status
206
 
207
  except Exception as e:
208
+ logger.error(f"Error processing ZIP: {str(e)}", exc_info=True)
209
  raise gr.Error(f"Error processing ZIP: {str(e)}")
210
 
211
  def process_tar_file(self, file_path: Path) -> str:
 
218
  Status message string
219
  """
220
  try:
221
+ logger.info(f"Processing WebDataset TAR file: {file_path}")
222
  video_count, image_count = webdataset_handler.process_webdataset_shard(
223
  file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
224
  )
 
231
  parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
232
 
233
  if not parts:
234
+ logger.warning("No supported media files found in WebDataset")
235
  return "No supported media files found in WebDataset"
236
 
237
  status = f"Successfully extracted {' and '.join(parts)} from WebDataset"
238
+ logger.info(status)
239
  gr.Info(status)
240
  return status
241
 
242
  except Exception as e:
243
+ logger.error(f"Error processing WebDataset tar file: {str(e)}", exc_info=True)
244
  raise gr.Error(f"Error processing WebDataset tar file: {str(e)}")
245
 
246
  def process_mp4_file(self, file_path: Path, original_name: str) -> str:
 
264
  target_path = VIDEOS_TO_SPLIT_PATH / f"{stem}___{counter}.mp4"
265
  counter += 1
266
 
267
+ logger.info(f"Processing video file: {original_name} -> {target_path}")
268
+
269
  # Copy the file to the target location
270
  shutil.copy2(file_path, target_path)
271
 
272
+ logger.info(f"Successfully stored video: {target_path.name}")
273
  gr.Info(f"Successfully stored video: {target_path.name}")
274
  return f"Successfully stored video: {target_path.name}"
275
 
276
  except Exception as e:
277
+ logger.error(f"Error processing video file: {str(e)}", exc_info=True)
278
+ raise gr.Error(f"Error processing video file: {str(e)}")
vms/services/importer/hub_dataset.py ADDED
@@ -0,0 +1,521 @@
1
+ """
2
+ Hugging Face Hub dataset browser for Video Model Studio.
3
+ Handles searching, viewing, and downloading datasets from the Hub.
4
+ """
5
+
6
+ import os
7
+ import shutil
8
+ import tempfile
9
+ import asyncio
10
+ import logging
11
+ import gradio as gr
12
+ from pathlib import Path
13
+ from typing import List, Dict, Optional, Tuple, Any, Union
14
+
15
+ from huggingface_hub import (
16
+ HfApi,
17
+ hf_hub_download,
18
+ snapshot_download,
19
+ list_datasets
20
+ )
21
+
22
+ from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
23
+ from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ class HubDatasetBrowser:
28
+ """Handles interactions with Hugging Face Hub datasets"""
29
+
30
+ def __init__(self, hf_api: HfApi):
31
+ """Initialize with HfApi instance
32
+
33
+ Args:
34
+ hf_api: Hugging Face Hub API instance
35
+ """
36
+ self.hf_api = hf_api
37
+
38
+ def search_datasets(self, query: str) -> List[List[str]]:
39
+ """Search for datasets on the Hugging Face Hub
40
+
41
+ Args:
42
+ query: Search query string
43
+
44
+ Returns:
45
+ List of datasets matching the query [id, title, downloads]
46
+ """
47
+ try:
48
+ # Start with some filters to find video-related datasets
49
+ search_terms = query.strip() if query and query.strip() else "video"
50
+ logger.info(f"Searching datasets with query: '{search_terms}'")
51
+
52
+ # Fetch datasets that match the search
53
+ datasets = list(self.hf_api.list_datasets(
54
+ search=search_terms,
55
+ limit=50
56
+ ))
57
+
58
+ # Format results for display
59
+ results = []
60
+ for ds in datasets:
61
+ # Extract relevant information
62
+ dataset_id = ds.id
63
+
64
+ # Safely get the title with fallbacks
65
+ card_data = getattr(ds, "card_data", None)
66
+ title = ""
67
+
68
+ if card_data is not None and isinstance(card_data, dict):
69
+ title = card_data.get("name", "")
70
+
71
+ if not title:
72
+ # Use the last part of the repo ID as a fallback
73
+ title = dataset_id.split("/")[-1]
74
+
75
+ # Safely get downloads
76
+ downloads = getattr(ds, "downloads", 0)
77
+ if downloads is None:
78
+ downloads = 0
79
+
80
+ results.append([dataset_id, title, downloads])
81
+
82
+ # Sort by downloads (most downloaded first)
83
+ results.sort(key=lambda x: x[2] if x[2] is not None else 0, reverse=True)
84
+
85
+ logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
86
+ return results
87
+
88
+ except Exception as e:
89
+ logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
90
+ return [[f"Error: {str(e)}", "", ""]]
91
+
92
+ def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
93
+ """Get detailed information about a dataset
94
+
95
+ Args:
96
+ dataset_id: The dataset ID to get information for
97
+
98
+ Returns:
99
+ Tuple of (markdown_info, file_counts, file_groups)
100
+ - markdown_info: Markdown formatted string with dataset information
101
+ - file_counts: Dictionary with counts of each file type
102
+ - file_groups: Dictionary with lists of filenames grouped by type
103
+ """
104
+ try:
105
+ if not dataset_id:
106
+ logger.warning("No dataset ID provided to get_dataset_info")
107
+ return "No dataset selected", {}, {}
108
+
109
+ logger.info(f"Getting info for dataset: {dataset_id}")
110
+
111
+ # Get detailed information about the dataset
112
+ dataset_info = self.hf_api.dataset_info(dataset_id)
113
+
114
+ # Format the information for display
115
+ info_text = f"## {dataset_info.id}\n\n"
116
+
117
+ # Add description if available (with safer access)
118
+ card_data = getattr(dataset_info, "card_data", None)
119
+ description = ""
120
+
121
+ if card_data is not None and isinstance(card_data, dict):
122
+ description = card_data.get("description", "")
123
+
124
+ if description:
125
+ info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"
126
+
127
+ # Add basic stats (with safer access)
128
+ downloads = getattr(dataset_info, 'downloads', None)
129
+ info_text += f"**Downloads:** {downloads if downloads is not None else 'N/A'}\n"
130
+
131
+ last_modified = getattr(dataset_info, 'last_modified', None)
132
+ info_text += f"**Last modified:** {last_modified if last_modified is not None else 'N/A'}\n"
133
+
134
+ # Show tags if available (with safer access)
135
+ tags = getattr(dataset_info, "tags", None) or []
136
+ if tags:
137
+ info_text += f"**Tags:** {', '.join(tags[:10])}\n\n"
138
+
139
+ # Group files by type
140
+ file_groups = {
141
+ "video": [],
142
+ "webdataset": []
143
+ }
144
+
145
+ siblings = getattr(dataset_info, "siblings", None) or []
146
+
147
+ # Extract files by type
148
+ for s in siblings:
149
+ if not hasattr(s, 'rfilename'):
150
+ continue
151
+
152
+ filename = s.rfilename
153
+ if filename.lower().endswith((".mp4", ".webm")):
154
+ file_groups["video"].append(filename)
155
+ elif filename.lower().endswith(".tar"):
156
+ file_groups["webdataset"].append(filename)
157
+
158
+ # Create file counts dictionary
159
+ file_counts = {
160
+ "video": len(file_groups["video"]),
161
+ "webdataset": len(file_groups["webdataset"])
162
+ }
163
+
164
+ logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
165
+ return info_text, file_counts, file_groups
166
+
167
+ except Exception as e:
168
+ logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
169
+ return f"Error loading dataset information: {str(e)}", {}, {}
170
+
171
+ async def download_file_group(self, dataset_id: str, file_type: str, enable_splitting: bool = True) -> str:
172
+ """Download all files of a specific type from the dataset
173
+
174
+ Args:
175
+ dataset_id: The dataset ID
176
+ file_type: Either "video" or "webdataset"
177
+ enable_splitting: Whether to enable automatic video splitting
178
+
179
+ Returns:
180
+ Status message
181
+ """
182
+ try:
183
+ # Get dataset info to retrieve file list
184
+ _, _, file_groups = self.get_dataset_info(dataset_id)
185
+
186
+ # Get the list of files for the specified type
187
+ files = file_groups.get(file_type, [])
188
+
189
+ if not files:
190
+ return f"No {file_type} files found in the dataset"
191
+
192
+ logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
193
+
194
+ # Track counts for status message
195
+ video_count = 0
196
+ image_count = 0
197
+
198
+ # Create a temporary directory for downloads
199
+ with tempfile.TemporaryDirectory() as temp_dir:
200
+ temp_path = Path(temp_dir)
201
+
202
+ # Process all files of the requested type
203
+ for filename in files:
204
+ try:
205
+ # Download the file
206
+ file_path = hf_hub_download(
207
+ repo_id=dataset_id,
208
+ filename=filename,
209
+ repo_type="dataset",
210
+ local_dir=temp_path
211
+ )
212
+
213
+ file_path = Path(file_path)
214
+ logger.info(f"Downloaded file to {file_path}")
215
+
216
+ # Process based on file type
217
+ if file_type == "video":
218
+ # Choose target directory based on auto-splitting setting
219
+ target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
220
+ target_path = target_dir / file_path.name
221
+
222
+ # Make sure filename is unique
223
+ counter = 1
224
+ while target_path.exists():
225
+ stem = Path(file_path.name).stem
226
+ if "___" in stem:
227
+ base_stem = stem.split("___")[0]
228
+ else:
229
+ base_stem = stem
230
+ target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
231
+ counter += 1
232
+
233
+ # Copy the video file
234
+ shutil.copy2(file_path, target_path)
235
+ logger.info(f"Processed video: {file_path.name} -> {target_path.name}")
236
+
237
+ # Try to download caption if it exists
238
+ try:
239
+ txt_filename = f"{Path(filename).stem}.txt"
240
+ for possible_path in [
241
+ Path(filename).with_suffix('.txt').as_posix(),
242
+ (Path(filename).parent / txt_filename).as_posix(),
243
+ ]:
244
+ try:
245
+ txt_path = hf_hub_download(
246
+ repo_id=dataset_id,
247
+ filename=possible_path,
248
+ repo_type="dataset",
249
+ local_dir=temp_path
250
+ )
251
+ shutil.copy2(txt_path, target_path.with_suffix('.txt'))
252
+ logger.info(f"Copied caption for {file_path.name}")
253
+ break
254
+ except Exception:
255
+ # Caption file doesn't exist at this path, try next
256
+ pass
257
+ except Exception as e:
258
+ logger.warning(f"Error trying to download caption: {e}")
259
+
260
+ video_count += 1
261
+
262
+ elif file_type == "webdataset":
263
+ # Process the WebDataset archive
264
+ try:
265
+ logger.info(f"Processing WebDataset file: {file_path}")
266
+ vid_count, img_count = webdataset_handler.process_webdataset_shard(
267
+ file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
268
+ )
269
+ video_count += vid_count
270
+ image_count += img_count
271
+ except Exception as e:
272
+ logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
273
+
274
+ except Exception as e:
275
+ logger.warning(f"Error processing file {filename}: {e}")
276
+
277
+ # Generate status message
278
+ if file_type == "video":
279
+ return f"Successfully imported {video_count} videos from dataset {dataset_id}"
280
+ elif file_type == "webdataset":
281
+ parts = []
282
+ if video_count > 0:
283
+ parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
284
+ if image_count > 0:
285
+ parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
286
+
287
+ if parts:
288
+ return f"Successfully imported {' and '.join(parts)} from WebDataset archives"
289
+ else:
290
+ return f"No media was found in the WebDataset archives"
291
+
292
+ return f"Unknown file type: {file_type}"
293
+
294
+ except Exception as e:
295
+ error_msg = f"Error downloading {file_type} files: {str(e)}"
296
+ logger.error(error_msg, exc_info=True)
297
+ return error_msg
298
+
299
+ async def download_dataset(self, dataset_id: str, enable_splitting: bool = True) -> Tuple[str, str]:
300
+ """Download a dataset and process its video/image content
301
+
302
+ Args:
303
+ dataset_id: The dataset ID to download
304
+ enable_splitting: Whether to enable automatic video splitting
305
+
306
+ Returns:
307
+ Tuple of (loading_msg, status_msg)
308
+ """
309
+ if not dataset_id:
310
+ logger.warning("No dataset ID provided for download")
311
+ return "No dataset selected", "Please select a dataset first"
312
+
313
+ try:
314
+ logger.info(f"Starting download of dataset: {dataset_id}")
315
+ loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."
316
+ status_msg = f"Downloading dataset: {dataset_id}..."
317
+
318
+ # Get dataset info to check for available files
319
+ dataset_info = self.hf_api.dataset_info(dataset_id)
320
+
321
+ # Check if there are video files or WebDataset files
322
+ video_files = []
323
+ tar_files = []
324
+
325
+ siblings = getattr(dataset_info, "siblings", None) or []
326
+ if siblings:
327
+ video_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith((".mp4", ".webm"))]
328
+ tar_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith(".tar")]
329
+
330
+ # Create a temporary directory for downloads
331
+ with tempfile.TemporaryDirectory() as temp_dir:
332
+ temp_path = Path(temp_dir)
333
+
334
+ # If we have video files, download them individually
335
+ if video_files:
336
+ loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
337
+ logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")
338
+
339
+ for i, video_file in enumerate(video_files):
340
+ # Download the video file
341
+ try:
342
+ file_path = hf_hub_download(
343
+ repo_id=dataset_id,
344
+ filename=video_file,
345
+ repo_type="dataset",
346
+ local_dir=temp_path
347
+ )
348
+
349
+ # Look for associated caption file
350
+ txt_filename = f"{Path(video_file).stem}.txt"
351
+ txt_path = None
352
+ for possible_path in [
353
+ Path(video_file).with_suffix('.txt').as_posix(),
354
+ (Path(video_file).parent / txt_filename).as_posix(),
355
+ ]:
356
+ try:
357
+ txt_path = hf_hub_download(
358
+ repo_id=dataset_id,
359
+ filename=possible_path,
360
+ repo_type="dataset",
361
+ local_dir=temp_path
362
+ )
363
+ logger.info(f"Found caption file for {video_file}: {possible_path}")
364
+ break
365
+ except Exception as e:
366
+ # Caption file doesn't exist at this path, try next
367
+ logger.debug(f"No caption at {possible_path}: {str(e)}")
368
+ pass
369
+
370
+ status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
371
+ logger.info(status_msg)
372
+ except Exception as e:
373
+ logger.warning(f"Error downloading {video_file}: {e}")
374
+
375
+ # If we have tar files, download them
376
+ if tar_files:
377
+ loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
378
+ logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")
379
+
380
+ for i, tar_file in enumerate(tar_files):
381
+ try:
382
+ file_path = hf_hub_download(
383
+ repo_id=dataset_id,
384
+ filename=tar_file,
385
+ repo_type="dataset",
386
+ local_dir=temp_path
387
+ )
388
+ status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
389
+ logger.info(status_msg)
390
+ except Exception as e:
391
+ logger.warning(f"Error downloading {tar_file}: {e}")
392
+
393
+ # If no specific files were found, try downloading the entire repo
394
+ if not video_files and not tar_files:
395
+ loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
396
+ logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")
397
+
398
+ try:
399
+ snapshot_download(
400
+ repo_id=dataset_id,
401
+ repo_type="dataset",
402
+ local_dir=temp_path
403
+ )
404
+ status_msg = f"Downloaded entire repository for {dataset_id}"
405
+ logger.info(status_msg)
406
+ except Exception as e:
407
+ logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
408
+ return loading_msg, f"Error downloading dataset: {str(e)}"
409
+
410
+ # Process the downloaded files
411
+ loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
412
+ logger.info(f"Processing downloaded files from {dataset_id}")
413
+
414
+ # Count imported files
415
+ video_count = 0
416
+ image_count = 0
417
+ tar_count = 0
418
+
419
+ # Process function for the event loop
420
+ async def process_files():
421
+ nonlocal video_count, image_count, tar_count
422
+
423
+ # Process all files in the temp directory
424
+ for root, _, files in os.walk(temp_path):
425
+ for file in files:
426
+ file_path = Path(root) / file
427
+
428
+ # Process videos
429
+ if file.lower().endswith((".mp4", ".webm")):
430
+ # Choose target path based on auto-splitting setting
431
+ target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
432
+ target_path = target_dir / file_path.name
433
+
434
+ # Make sure filename is unique
435
+ counter = 1
436
+ while target_path.exists():
437
+ stem = Path(file_path.name).stem
438
+ if "___" in stem:
439
+ base_stem = stem.split("___")[0]
440
+ else:
441
+ base_stem = stem
442
+ target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
443
+ counter += 1
444
+
445
+ # Copy the video file
446
+ shutil.copy2(file_path, target_path)
447
+ logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")
448
+
449
+ # Copy associated caption file if it exists
450
+ txt_path = file_path.with_suffix('.txt')
451
+ if txt_path.exists():
452
+ shutil.copy2(txt_path, target_path.with_suffix('.txt'))
453
+ logger.info(f"Copied caption for {file_path.name}")
454
+
455
+ video_count += 1
456
+
457
+ # Process images
458
+ elif is_image_file(file_path):
459
+ target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
460
+
461
+ counter = 1
462
+ while target_path.exists():
463
+ target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
464
+ counter += 1
465
+
466
+ if normalize_image(file_path, target_path):
467
+ logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")
468
+
469
+ # Copy caption if available
470
+ txt_path = file_path.with_suffix('.txt')
471
+ if txt_path.exists():
472
+ caption = txt_path.read_text()
473
+ caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
474
+ target_path.with_suffix('.txt').write_text(caption)
475
+ logger.info(f"Processed caption for {file_path.name}")
476
+
477
+ image_count += 1
478
+
479
+ # Process WebDataset files
480
+ elif file.lower().endswith(".tar"):
481
+ # Process the WebDataset archive
482
+ try:
483
+ logger.info(f"Processing WebDataset file from dataset: {file}")
484
+ vid_count, img_count = webdataset_handler.process_webdataset_shard(
485
+ file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
486
+ )
487
+ tar_count += 1
488
+ video_count += vid_count
489
+ image_count += img_count
490
+ logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
491
+ except Exception as e:
492
+ logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
493
+
494
+ # Run the processing asynchronously
495
+ await process_files()
496
+
497
+ # Generate final status message
498
+ parts = []
499
+ if video_count > 0:
500
+ parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
501
+ if image_count > 0:
502
+ parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
503
+ if tar_count > 0:
504
+ parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")
505
+
506
+ if parts:
507
+ status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
508
+ loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
509
+ logger.info(status)
510
+ else:
511
+ status = f"No supported media files found in dataset {dataset_id}"
512
+ loading_msg = f"{loading_msg}\n\n⚠️ {status}"
513
+ logger.warning(status)
514
+
515
+ gr.Info(status)
516
+ return loading_msg, status
517
+
518
+ except Exception as e:
519
+ error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
520
+ logger.error(error_msg, exc_info=True)
521
+ return f"Error: {error_msg}", error_msg
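As a quick orientation, here is a hedged usage sketch (not part of this commit): the search and info methods are synchronous, while the download helpers are coroutines and need an event loop. The dataset picked below is just whatever the search returns first.

```python
import asyncio
from huggingface_hub import HfApi
from vms.services.importer.hub_dataset import HubDatasetBrowser

browser = HubDatasetBrowser(HfApi())

rows = browser.search_datasets("video")                 # [[dataset_id, title, downloads], ...]
info_md, counts, groups = browser.get_dataset_info(rows[0][0])
print(counts)                                           # e.g. {'video': 12, 'webdataset': 0}

# download_dataset is async, so run it inside an event loop
loading_msg, status = asyncio.run(browser.download_dataset(rows[0][0], enable_splitting=True))
print(status)
```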
vms/services/importer/import_service.py ADDED
@@ -0,0 +1,102 @@
1
+ """
2
+ Main Import Service for Video Model Studio.
3
+ Delegates to specialized handler classes for different import types.
4
+ """
5
+
6
+ import logging
7
+ from typing import List, Dict, Optional, Tuple, Any, Union
8
+ from pathlib import Path
9
+ import gradio as gr
10
+
11
+ from huggingface_hub import HfApi
12
+
13
+ from .file_upload import FileUploadHandler
14
+ from .youtube import YouTubeDownloader
15
+ from .hub_dataset import HubDatasetBrowser
16
+ from vms.config import HF_API_TOKEN
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class ImportService:
21
+ """Main service class for handling imports from various sources"""
22
+
23
+ def __init__(self):
24
+ """Initialize the import service and handlers"""
25
+ self.hf_api = HfApi(token=HF_API_TOKEN)
26
+ self.file_handler = FileUploadHandler()
27
+ self.youtube_handler = YouTubeDownloader()
28
+ self.hub_browser = HubDatasetBrowser(self.hf_api)
29
+
30
+ def process_uploaded_files(self, file_paths: List[str]) -> str:
31
+ """Process uploaded file (ZIP, TAR, MP4, or image)
32
+
33
+ Args:
34
+ file_paths: File paths to the uploaded files from Gradio
35
+
36
+ Returns:
37
+ Status message string
38
+ """
39
+ if not file_paths or len(file_paths) == 0:
40
+ logger.warning("No files provided to process_uploaded_files")
41
+ return "No files provided"
42
+
43
+ return self.file_handler.process_uploaded_files(file_paths)
44
+
45
+ def download_youtube_video(self, url: str, progress=None) -> str:
46
+ """Download a video from YouTube
47
+
48
+ Args:
49
+ url: YouTube video URL
50
+ progress: Optional Gradio progress indicator
51
+
52
+ Returns:
53
+ Status message string
54
+ """
55
+ return self.youtube_handler.download_video(url, progress)
56
+
57
+ def search_datasets(self, query: str) -> List[List[str]]:
58
+ """Search for datasets on the Hugging Face Hub
59
+
60
+ Args:
61
+ query: Search query string
62
+
63
+ Returns:
64
+ List of datasets matching the query [id, title, downloads]
65
+ """
66
+ return self.hub_browser.search_datasets(query)
67
+
68
+ def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
69
+ """Get detailed information about a dataset
70
+
71
+ Args:
72
+ dataset_id: The dataset ID to get information for
73
+
74
+ Returns:
75
+ Tuple of (markdown_info, file_counts, file_groups)
76
+ """
77
+ return self.hub_browser.get_dataset_info(dataset_id)
78
+
79
+ async def download_dataset(self, dataset_id: str, enable_splitting: bool = True) -> Tuple[str, str]:
80
+ """Download a dataset and process its video/image content
81
+
82
+ Args:
83
+ dataset_id: The dataset ID to download
84
+ enable_splitting: Whether to enable automatic video splitting
85
+
86
+ Returns:
87
+ Tuple of (loading_msg, status_msg)
88
+ """
89
+ return await self.hub_browser.download_dataset(dataset_id, enable_splitting)
90
+
91
+ async def download_file_group(self, dataset_id: str, file_type: str, enable_splitting: bool = True) -> str:
92
+ """Download a group of files (videos or WebDatasets)
93
+
94
+ Args:
95
+ dataset_id: The dataset ID
96
+ file_type: Type of file ("video" or "webdataset")
97
+ enable_splitting: Whether to enable automatic video splitting
98
+
99
+ Returns:
100
+ Status message
101
+ """
102
+ return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting)
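A hedged sketch (not part of this commit) of how the facade might be driven outside the Gradio UI; the archive path and dataset id below are placeholders:

```python
import asyncio
from vms.services.importer import ImportService

service = ImportService()   # wires FileUploadHandler, YouTubeDownloader and HubDatasetBrowser together

# synchronous entry point for local archives (placeholder path)
print(service.process_uploaded_files(["/path/to/my_dataset.zip"]))

# async entry point: fetch only the video files of a Hub dataset (placeholder dataset id)
status = asyncio.run(service.download_file_group("some-user/some-video-dataset", "video"))
print(status)
```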
vms/services/importer/youtube.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ YouTube downloader for Video Model Studio.
3
+ Handles downloading videos from YouTube URLs.
4
+ """
5
+
6
+ import logging
7
+ import gradio as gr
8
+ from pathlib import Path
9
+ from typing import Optional, Any, Union, Callable
10
+
11
+ from pytubefix import YouTube
12
+
13
+ from vms.config import VIDEOS_TO_SPLIT_PATH
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class YouTubeDownloader:
18
+ """Handles downloading videos from YouTube"""
19
+
20
+ def download_video(self, url: str, progress: Optional[Callable] = None) -> str:
21
+ """Download a video from YouTube
22
+
23
+ Args:
24
+ url: YouTube video URL
25
+ progress: Optional Gradio progress indicator
26
+
27
+ Returns:
28
+ Status message string
29
+ """
30
+ if not url or not url.strip():
31
+ logger.warning("No YouTube URL provided")
32
+ return "Please enter a YouTube URL"
33
+
34
+ try:
35
+ logger.info(f"Downloading YouTube video: {url}")
36
+
37
+ # Extract video ID and create YouTube object
38
+ yt = YouTube(url, on_progress_callback=lambda stream, chunk, bytes_remaining:
39
+ progress((1 - bytes_remaining / stream.filesize), desc="Downloading...")
40
+ if progress else None)
41
+
42
+ video_id = yt.video_id
43
+ output_path = VIDEOS_TO_SPLIT_PATH / f"{video_id}.mp4"
44
+
45
+ # Download highest quality progressive MP4
46
+ if progress:
47
+ logger.debug("Getting video streams...")
48
+ progress(0, desc="Getting video streams...")
49
+ video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
50
+
51
+ if not video:
52
+ logger.error("Could not find a compatible video format")
53
+ gr.Error("Could not find a compatible video format")
54
+ return "Could not find a compatible video format"
55
+
56
+ # Download the video
57
+ if progress:
58
+ logger.info("Starting YouTube video download...")
59
+ progress(0, desc="Starting download...")
60
+
61
+ video.download(output_path=str(VIDEOS_TO_SPLIT_PATH), filename=f"{video_id}.mp4")
62
+
63
+ # Update UI
64
+ if progress:
65
+ logger.info("YouTube video download complete!")
66
+ gr.Info("YouTube video download complete!")
67
+ progress(1, desc="Download complete!")
68
+ return f"Successfully downloaded video: {yt.title}"
69
+
70
+ except Exception as e:
71
+ logger.error(f"Error downloading YouTube video: {str(e)}", exc_info=True)
72
+ gr.Error(f"Error downloading video: {str(e)}")
73
+ return f"Error downloading video: {str(e)}"
vms/services/monitoring.py ADDED
@@ -0,0 +1,361 @@
1
+ """
2
+ System monitoring service for Video Model Studio.
3
+ Tracks system resources like CPU, memory, and other metrics.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import logging
9
+ import platform
10
+ import threading
11
+ from datetime import datetime, timedelta
12
+ from collections import deque
13
+ from typing import Dict, List, Optional, Tuple, Any
14
+
15
+ import psutil
16
+
17
+ # Force the use of the Agg backend which is thread-safe
18
+ import matplotlib
19
+ matplotlib.use('Agg') # Must be before importing pyplot
20
+ import matplotlib.pyplot as plt
21
+
22
+ import numpy as np
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ class MonitoringService:
27
+ """Service for monitoring system resources and performance"""
28
+
29
+ def __init__(self, history_minutes: int = 10, sample_interval: int = 5):
30
+ """Initialize the monitoring service
31
+
32
+ Args:
33
+ history_minutes: How many minutes of history to keep
34
+ sample_interval: How many seconds between samples
35
+ """
36
+ self.history_minutes = history_minutes
37
+ self.sample_interval = sample_interval
38
+ self.max_samples = (history_minutes * 60) // sample_interval
39
+
40
+ # Initialize data structures for metrics
41
+ self.timestamps = deque(maxlen=self.max_samples)
42
+ self.cpu_percent = deque(maxlen=self.max_samples)
43
+ self.memory_percent = deque(maxlen=self.max_samples)
44
+ self.memory_used = deque(maxlen=self.max_samples)
45
+ self.memory_available = deque(maxlen=self.max_samples)
46
+
47
+ # CPU temperature history (might not be available on all systems)
48
+ self.cpu_temp = deque(maxlen=self.max_samples)
49
+
50
+ # Per-core CPU history
51
+ self.cpu_cores_percent = {}
52
+
53
+ # Track if the monitoring thread is running
54
+ self.is_running = False
55
+ self.thread = None
56
+
57
+ # Initialize with current values
58
+ self.collect_metrics()
59
+
60
+ def collect_metrics(self) -> Dict[str, Any]:
61
+ """Collect current system metrics
62
+
63
+ Returns:
64
+ Dictionary of current metrics
65
+ """
66
+ metrics = {
67
+ 'timestamp': datetime.now(),
68
+ 'cpu_percent': psutil.cpu_percent(interval=0.1),
69
+ 'memory_percent': psutil.virtual_memory().percent,
70
+ 'memory_used': psutil.virtual_memory().used / (1024**3), # GB
71
+ 'memory_available': psutil.virtual_memory().available / (1024**3), # GB
72
+ 'cpu_temp': None,
73
+ 'per_cpu_percent': psutil.cpu_percent(interval=0.1, percpu=True)
74
+ }
75
+
76
+ # Try to get CPU temperature (platform specific)
77
+ try:
78
+ if platform.system() == 'Linux':
79
+ # Try to get temperature from psutil
80
+ temps = psutil.sensors_temperatures()
81
+ for name, entries in temps.items():
82
+ if name.startswith(('coretemp', 'k10temp', 'cpu_thermal')):
83
+ metrics['cpu_temp'] = entries[0].current
84
+ break
85
+ elif platform.system() == 'Darwin': # macOS
86
+ # On macOS, we could use SMC reader but it requires additional dependencies
87
+ # Leaving as None for now
88
+ pass
89
+ elif platform.system() == 'Windows':
90
+ # Windows might require WMI, leaving as None for simplicity
91
+ pass
92
+ except (AttributeError, KeyError, IndexError, NotImplementedError):
93
+ # Sensors not available
94
+ pass
95
+
96
+ return metrics
97
+
98
+ def update_history(self, metrics: Dict[str, Any]) -> None:
99
+ """Update metric history with new values
100
+
101
+ Args:
102
+ metrics: New metrics to add to history
103
+ """
104
+ self.timestamps.append(metrics['timestamp'])
105
+ self.cpu_percent.append(metrics['cpu_percent'])
106
+ self.memory_percent.append(metrics['memory_percent'])
107
+ self.memory_used.append(metrics['memory_used'])
108
+ self.memory_available.append(metrics['memory_available'])
109
+
110
+ if metrics['cpu_temp'] is not None:
111
+ self.cpu_temp.append(metrics['cpu_temp'])
112
+
113
+ # Update per-core CPU metrics
114
+ for i, percent in enumerate(metrics['per_cpu_percent']):
115
+ if i not in self.cpu_cores_percent:
116
+ self.cpu_cores_percent[i] = deque(maxlen=self.max_samples)
117
+ self.cpu_cores_percent[i].append(percent)
118
+
119
+ def start_monitoring(self) -> None:
120
+ """Start background thread for collecting metrics"""
121
+ if self.is_running:
122
+ logger.warning("Monitoring thread already running")
123
+ return
124
+
125
+ self.is_running = True
126
+
127
+ def _monitor_loop():
128
+ while self.is_running:
129
+ try:
130
+ metrics = self.collect_metrics()
131
+ self.update_history(metrics)
132
+ time.sleep(self.sample_interval)
133
+ except Exception as e:
134
+ logger.error(f"Error in monitoring thread: {str(e)}", exc_info=True)
135
+ time.sleep(self.sample_interval)
136
+
137
+ self.thread = threading.Thread(target=_monitor_loop, daemon=True)
138
+ self.thread.start()
139
+ logger.info("System monitoring thread started")
140
+
141
+ def stop_monitoring(self) -> None:
142
+ """Stop the monitoring thread"""
143
+ if not self.is_running:
144
+ return
145
+
146
+ self.is_running = False
147
+ if self.thread:
148
+ self.thread.join(timeout=1.0)
149
+ logger.info("System monitoring thread stopped")
150
+
151
+ def get_current_metrics(self) -> Dict[str, Any]:
152
+ """Get current system metrics
153
+
154
+ Returns:
155
+ Dictionary with current system metrics
156
+ """
157
+ return self.collect_metrics()
158
+
159
+ def get_system_info(self) -> Dict[str, Any]:
160
+ """Get general system information
161
+
162
+ Returns:
163
+ Dictionary with system details
164
+ """
165
+ cpu_info = {
166
+ 'cores_physical': psutil.cpu_count(logical=False),
167
+ 'cores_logical': psutil.cpu_count(logical=True),
168
+ 'current_frequency': None,
169
+ 'architecture': platform.machine(),
170
+ }
171
+
172
+ # Try to get CPU frequency
173
+ try:
174
+ cpu_freq = psutil.cpu_freq()
175
+ if cpu_freq:
176
+ cpu_info['current_frequency'] = cpu_freq.current
177
+ except Exception:
178
+ pass
179
+
180
+ memory_info = {
181
+ 'total': psutil.virtual_memory().total / (1024**3), # GB
182
+ 'available': psutil.virtual_memory().available / (1024**3), # GB
183
+ 'used': psutil.virtual_memory().used / (1024**3), # GB
184
+ 'percent': psutil.virtual_memory().percent
185
+ }
186
+
187
+ disk_info = {}
188
+ for part in psutil.disk_partitions(all=False):
189
+ if os.name == 'nt' and ('cdrom' in part.opts or part.fstype == ''):
190
+ # Skip CD-ROM drives on Windows
191
+ continue
192
+ try:
193
+ usage = psutil.disk_usage(part.mountpoint)
194
+ disk_info[part.mountpoint] = {
195
+ 'total': usage.total / (1024**3), # GB
196
+ 'used': usage.used / (1024**3), # GB
197
+ 'free': usage.free / (1024**3), # GB
198
+ 'percent': usage.percent
199
+ }
200
+ except PermissionError:
201
+ continue
202
+
203
+ sys_info = {
204
+ 'system': platform.system(),
205
+ 'version': platform.version(),
206
+ 'platform': platform.platform(),
207
+ 'processor': platform.processor(),
208
+ 'hostname': platform.node(),
209
+ 'python_version': platform.python_version(),
210
+ 'uptime': time.time() - psutil.boot_time()
211
+ }
212
+
213
+ return {
214
+ 'cpu': cpu_info,
215
+ 'memory': memory_info,
216
+ 'disk': disk_info,
217
+ 'system': sys_info,
218
+ }
219
+
220
+ def generate_cpu_plot(self) -> plt.Figure:
221
+ """Generate a plot of CPU usage over time
222
+
223
+ Returns:
224
+ Matplotlib figure with CPU usage plot
225
+ """
226
+ fig, ax = plt.subplots(figsize=(10, 5))
227
+
228
+ if not self.timestamps:
229
+ ax.set_title("No CPU data available yet")
230
+ return fig
231
+
232
+ x = [t.strftime('%H:%M:%S') for t in self.timestamps]
233
+ if len(x) > 10:
234
+ # Show fewer x-axis labels for readability
235
+ step = len(x) // 10
236
+ ax.set_xticks(range(0, len(x), step))
237
+ ax.set_xticklabels([x[i] for i in range(0, len(x), step)])
238
+
239
+ ax.plot(x, list(self.cpu_percent), 'b-', label='CPU Usage %')
240
+
241
+ if self.cpu_temp and len(self.cpu_temp) > 0:
242
+ # Plot temperature on a secondary y-axis if available
243
+ ax2 = ax.twinx()
244
+ ax2.plot(x[:len(self.cpu_temp)], list(self.cpu_temp), 'r-', label='CPU Temp °C')
245
+ ax2.set_ylabel('Temperature (°C)', color='r')
246
+ ax2.tick_params(axis='y', colors='r')
247
+
248
+ ax.set_title('CPU Usage Over Time')
249
+ ax.set_xlabel('Time')
250
+ ax.set_ylabel('Usage %')
251
+ ax.grid(True, alpha=0.3)
252
+ ax.set_ylim(0, 100)
253
+
254
+ # Add legend
255
+ lines, labels = ax.get_legend_handles_labels()
256
+ if 'ax2' in locals():
257
+ lines2, labels2 = ax2.get_legend_handles_labels()
258
+ ax.legend(lines + lines2, labels + labels2, loc='upper left')
259
+ else:
260
+ ax.legend(loc='upper left')
261
+
262
+ plt.tight_layout()
263
+ return fig
264
+
265
+ def generate_memory_plot(self) -> plt.Figure:
266
+ """Generate a plot of memory usage over time
267
+
268
+ Returns:
269
+ Matplotlib figure with memory usage plot
270
+ """
271
+ fig, ax = plt.subplots(figsize=(10, 5))
272
+
273
+ if not self.timestamps:
274
+ ax.set_title("No memory data available yet")
275
+ return fig
276
+
277
+ x = [t.strftime('%H:%M:%S') for t in self.timestamps]
278
+ if len(x) > 10:
279
+ # Show fewer x-axis labels for readability
280
+ step = len(x) // 10
281
+ ax.set_xticks(range(0, len(x), step))
282
+ ax.set_xticklabels([x[i] for i in range(0, len(x), step)])
283
+
284
+ ax.plot(x, list(self.memory_percent), 'g-', label='Memory Usage %')
285
+
286
+ # Add secondary y-axis for absolute memory values
287
+ ax2 = ax.twinx()
288
+ ax2.plot(x, list(self.memory_used), 'm--', label='Used (GB)')
289
+ ax2.plot(x, list(self.memory_available), 'c--', label='Available (GB)')
290
+ ax2.set_ylabel('Memory (GB)')
291
+
292
+ ax.set_title('Memory Usage Over Time')
293
+ ax.set_xlabel('Time')
294
+ ax.set_ylabel('Usage %')
295
+ ax.grid(True, alpha=0.3)
296
+ ax.set_ylim(0, 100)
297
+
298
+ # Add legend
299
+ lines, labels = ax.get_legend_handles_labels()
300
+ lines2, labels2 = ax2.get_legend_handles_labels()
301
+ ax.legend(lines + lines2, labels + labels2, loc='upper left')
302
+
303
+ plt.tight_layout()
304
+ return fig
305
+
306
+ def generate_per_core_plot(self) -> plt.Figure:
307
+ """Generate a plot of per-core CPU usage
308
+
309
+ Returns:
310
+ Matplotlib figure with per-core CPU usage
311
+ """
312
+ num_cores = len(self.cpu_cores_percent)
313
+ if num_cores == 0:
314
+ # No data yet
315
+ fig, ax = plt.subplots(figsize=(10, 5))
316
+ ax.set_title("No per-core CPU data available yet")
317
+ return fig
318
+
319
+ # Determine grid layout based on number of cores
320
+ if num_cores <= 4:
321
+ rows, cols = 2, 2
322
+ elif num_cores <= 6:
323
+ rows, cols = 2, 3
324
+ elif num_cores <= 9:
325
+ rows, cols = 3, 3
326
+ elif num_cores <= 12:
327
+ rows, cols = 3, 4
328
+ else:
329
+ rows, cols = 4, 4
330
+
331
+ fig, axes = plt.subplots(rows, cols, figsize=(12, 8), sharex=True, sharey=True)
332
+ axes = axes.flatten()
333
+
334
+ x = [t.strftime('%H:%M:%S') for t in self.timestamps]
335
+ if len(x) > 5:
336
+ # Show fewer x-axis labels for readability
337
+ step = len(x) // 5
338
+ else:
339
+ step = 1
340
+
341
+ for i, (core_id, percentages) in enumerate(self.cpu_cores_percent.items()):
342
+ if i >= len(axes):
343
+ break
344
+
345
+ ax = axes[i]
346
+ ax.plot(x[:len(percentages)], list(percentages), 'b-')
347
+ ax.set_title(f'Core {core_id}')
348
+ ax.set_ylim(0, 100)
349
+ ax.grid(True, alpha=0.3)
350
+
351
+ # Add x-axis labels sparingly for readability
352
+ if i >= len(axes) - cols: # Only for bottom row
353
+ ax.set_xticks(range(0, len(x), step))
354
+ ax.set_xticklabels([x[i] for i in range(0, len(x), step)], rotation=45)
355
+
356
+ # Hide unused subplots
357
+ for i in range(num_cores, len(axes)):
358
+ axes[i].set_visible(False)
359
+
360
+ plt.tight_layout()
361
+ return fig
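Note: `start_monitoring()` and `collect_metrics()` are referenced above but sit outside this hunk. As a rough, hedged sketch of what such a background sampling loop can look like, here is a minimal standalone version; only `is_running`, `thread`, `stop_monitoring()` and `collect_metrics()` are taken from the code shown, while the loop body, interval and history size are illustrative assumptions, not the real `MonitoringService`.

```python
import threading
import time
from collections import deque
from datetime import datetime

import psutil


class TinyMonitor:
    """Illustrative sketch only, not the actual MonitoringService."""

    def __init__(self, interval: float = 1.0, history: int = 300):
        self.interval = interval
        self.is_running = False
        self.thread = None
        # Rolling buffers consumed by plotting helpers like the ones above
        self.timestamps = deque(maxlen=history)
        self.cpu_percent = deque(maxlen=history)

    def collect_metrics(self):
        """Take one sample of the metrics the plots need."""
        return {
            "timestamp": datetime.now(),
            "cpu_percent": psutil.cpu_percent(interval=None),
            "memory_percent": psutil.virtual_memory().percent,
        }

    def _loop(self):
        while self.is_running:
            metrics = self.collect_metrics()
            self.timestamps.append(metrics["timestamp"])
            self.cpu_percent.append(metrics["cpu_percent"])
            time.sleep(self.interval)

    def start_monitoring(self):
        """Start the background sampling thread (idempotent)."""
        if self.is_running:
            return
        self.is_running = True
        self.thread = threading.Thread(target=self._loop, daemon=True)
        self.thread.start()

    def stop_monitoring(self):
        """Stop the thread, joining with a short timeout as above."""
        if not self.is_running:
            return
        self.is_running = False
        if self.thread:
            self.thread.join(timeout=1.0)
```

Stopping with `stop_monitoring()` then joins the thread with a short timeout, exactly as in the method shown above.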
vms/tabs/__init__.py CHANGED
@@ -6,6 +6,7 @@ from .import_tab import ImportTab
6
  from .split_tab import SplitTab
7
  from .caption_tab import CaptionTab
8
  from .train_tab import TrainTab
 
9
  from .manage_tab import ManageTab
10
 
11
  __all__ = [
@@ -13,5 +14,6 @@ __all__ = [
13
  'SplitTab',
14
  'CaptionTab',
15
  'TrainTab',
 
16
  'ManageTab'
17
  ]
 
6
  from .split_tab import SplitTab
7
  from .caption_tab import CaptionTab
8
  from .train_tab import TrainTab
9
+ from .monitor_tab import MonitorTab
10
  from .manage_tab import ManageTab
11
 
12
  __all__ = [
 
14
  'SplitTab',
15
  'CaptionTab',
16
  'TrainTab',
17
+ 'MonitorTab',
18
  'ManageTab'
19
  ]
vms/tabs/import_tab/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Import tab for Video Model Studio.
3
+ """
4
+
5
+ from .upload_tab import UploadTab
6
+ from .youtube_tab import YouTubeTab
7
+ from .hub_tab import HubTab
8
+ from .import_tab import ImportTab
9
+
10
+ __all__ = ['UploadTab', 'YouTubeTab', 'HubTab', 'ImportTab']
vms/tabs/import_tab/hub_tab.py ADDED
@@ -0,0 +1,273 @@
1
+ """
2
+ Hugging Face Hub tab for Video Model Studio UI.
3
+ Handles browsing, searching, and importing datasets from the Hugging Face Hub.
4
+ """
5
+
6
+ import gradio as gr
7
+ import logging
8
+ import asyncio
9
+ from pathlib import Path
10
+ from typing import Dict, Any, List, Optional, Tuple
11
+
12
+ from ..base_tab import BaseTab
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class HubTab(BaseTab):
17
+ """Hub tab for importing datasets from Hugging Face Hub"""
18
+
19
+ def __init__(self, app_state):
20
+ super().__init__(app_state)
21
+ self.id = "hub_tab"
22
+ self.title = "Import from Hugging Face"
23
+
24
+ def create(self, parent=None) -> gr.Tab:
25
+ """Create the Hub tab UI components"""
26
+ with gr.Tab(self.title, id=self.id) as tab:
27
+ with gr.Column():
28
+ with gr.Row():
29
+ gr.Markdown("## Import from Hub datasets")
30
+
31
+ with gr.Row():
32
+ gr.Markdown("Search for datasets with videos or WebDataset archives:")
33
+
34
+ with gr.Row():
35
+ self.components["dataset_search"] = gr.Textbox(
36
+ label="Search Hugging Face Datasets",
37
+ placeholder="Search for video datasets..."
38
+ )
39
+
40
+ with gr.Row():
41
+ self.components["dataset_search_btn"] = gr.Button("Search Datasets", variant="primary")
42
+
43
+ # Dataset browser results section
44
+ with gr.Row(visible=False) as dataset_results_row:
45
+ self.components["dataset_results_row"] = dataset_results_row
46
+
47
+ with gr.Column(scale=3):
48
+ self.components["dataset_results"] = gr.Dataframe(
49
+ headers=["id", "title", "downloads"],
50
+ interactive=False,
51
+ wrap=True,
52
+ row_count=10,
53
+ label="Dataset Results"
54
+ )
55
+
56
+ with gr.Column(scale=3):
57
+ # Dataset info and state
58
+ self.components["dataset_info"] = gr.Markdown("Select a dataset to see details")
59
+ self.components["dataset_id"] = gr.State(value=None)
60
+ self.components["file_type"] = gr.State(value=None)
61
+
62
+ # Files section that appears when a dataset is selected
63
+ with gr.Column(visible=False) as files_section:
64
+ self.components["files_section"] = files_section
65
+
66
+ gr.Markdown("## Files:")
67
+
68
+ # Video files row (appears if videos are present)
69
+ with gr.Row(visible=False) as video_files_row:
70
+ self.components["video_files_row"] = video_files_row
71
+
72
+ with gr.Column(scale=4):
73
+ self.components["video_count_text"] = gr.Markdown("Contains 0 video files")
74
+
75
+ with gr.Column(scale=1):
76
+ self.components["download_videos_btn"] = gr.Button("Download", variant="primary")
77
+
78
+ # WebDataset files row (appears if tar files are present)
79
+ with gr.Row(visible=False) as webdataset_files_row:
80
+ self.components["webdataset_files_row"] = webdataset_files_row
81
+
82
+ with gr.Column(scale=4):
83
+ self.components["webdataset_count_text"] = gr.Markdown("Contains 0 WebDataset (.tar) files")
84
+
85
+ with gr.Column(scale=1):
86
+ self.components["download_webdataset_btn"] = gr.Button("Download", variant="primary")
87
+
88
+ # Status and loading indicators
89
+ self.components["dataset_loading"] = gr.Markdown(visible=False)
90
+
91
+ return tab
92
+
93
+ def connect_events(self) -> None:
94
+ """Connect event handlers to UI components"""
95
+ # Dataset search event
96
+ self.components["dataset_search_btn"].click(
97
+ fn=self.search_datasets,
98
+ inputs=[self.components["dataset_search"]],
99
+ outputs=[
100
+ self.components["dataset_results"],
101
+ self.components["dataset_results_row"]
102
+ ]
103
+ )
104
+
105
+ # Dataset selection event
106
+ self.components["dataset_results"].select(
107
+ fn=self.display_dataset_info,
108
+ outputs=[
109
+ self.components["dataset_info"],
110
+ self.components["dataset_id"],
111
+ self.components["files_section"],
112
+ self.components["video_files_row"],
113
+ self.components["video_count_text"],
114
+ self.components["webdataset_files_row"],
115
+ self.components["webdataset_count_text"]
116
+ ]
117
+ )
118
+
119
+ # Download videos button
120
+ self.components["download_videos_btn"].click(
121
+ fn=self.set_file_type_and_return,
122
+ outputs=[self.components["file_type"]]
123
+ ).then(
124
+ fn=self.download_file_group,
125
+ inputs=[
126
+ self.components["dataset_id"],
127
+ self.components["enable_automatic_video_split"],
128
+ self.components["file_type"]
129
+ ],
130
+ outputs=[
131
+ self.components["dataset_loading"],
132
+ self.components["import_status"]
133
+ ]
134
+ ).success(
135
+ fn=self.app.tabs["import_tab"].on_import_success,
136
+ inputs=[
137
+ self.components["enable_automatic_video_split"],
138
+ self.components["enable_automatic_content_captioning"],
139
+ self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
140
+ ],
141
+ outputs=[
142
+ self.app.tabs_component,
143
+ self.app.tabs["split_tab"].components["video_list"],
144
+ self.app.tabs["split_tab"].components["detect_status"]
145
+ ]
146
+ )
147
+
148
+ # Download WebDataset button
149
+ self.components["download_webdataset_btn"].click(
150
+ fn=self.set_file_type_and_return_webdataset,
151
+ outputs=[self.components["file_type"]]
152
+ ).then(
153
+ fn=self.download_file_group,
154
+ inputs=[
155
+ self.components["dataset_id"],
156
+ self.components["enable_automatic_video_split"],
157
+ self.components["file_type"]
158
+ ],
159
+ outputs=[
160
+ self.components["dataset_loading"],
161
+ self.components["import_status"]
162
+ ]
163
+ ).success(
164
+ fn=self.app.tabs["import_tab"].on_import_success,
165
+ inputs=[
166
+ self.components["enable_automatic_video_split"],
167
+ self.components["enable_automatic_content_captioning"],
168
+ self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
169
+ ],
170
+ outputs=[
171
+ self.app.tabs_component,
172
+ self.app.tabs["split_tab"].components["video_list"],
173
+ self.app.tabs["split_tab"].components["detect_status"]
174
+ ]
175
+ )
176
+
177
+ def set_file_type_and_return(self):
178
+ """Set file type to video and return it"""
179
+ return "video"
180
+
181
+ def set_file_type_and_return_webdataset(self):
182
+ """Set file type to webdataset and return it"""
183
+ return "webdataset"
184
+
185
+ def search_datasets(self, query: str):
186
+ """Search datasets on the Hub matching the query"""
187
+ try:
188
+ logger.info(f"Searching for datasets with query: '{query}'")
189
+ results = self.app.importer.search_datasets(query)
190
+ return results, gr.update(visible=True)
191
+ except Exception as e:
192
+ logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
193
+ return [[f"Error: {str(e)}", "", ""]], gr.update(visible=True)
194
+
195
+ def display_dataset_info(self, evt: gr.SelectData):
196
+ """Display detailed information about the selected dataset"""
197
+ try:
198
+ if not evt or not evt.value:
199
+ logger.warning("No dataset selected in display_dataset_info")
200
+ return (
201
+ "No dataset selected", # dataset_info
202
+ None, # dataset_id
203
+ gr.update(visible=False), # files_section
204
+ gr.update(visible=False), # video_files_row
205
+ "", # video_count_text
206
+ gr.update(visible=False), # webdataset_files_row
207
+ "" # webdataset_count_text
208
+ )
209
+
210
+ dataset_id = evt.value[0] if isinstance(evt.value, list) else evt.value
211
+ logger.info(f"Getting dataset info for: {dataset_id}")
212
+
213
+ # Use the importer service to get dataset info
214
+ info_text, file_counts, _ = self.app.importer.get_dataset_info(dataset_id)
215
+
216
+ # Get counts of each file type
217
+ video_count = file_counts.get("video", 0)
218
+ webdataset_count = file_counts.get("webdataset", 0)
219
+
220
+ # Return all the required outputs individually
221
+ return (
222
+ info_text, # dataset_info
223
+ dataset_id, # dataset_id
224
+ gr.update(visible=True), # files_section
225
+ gr.update(visible=video_count > 0), # video_files_row
226
+ f"Contains {video_count} video file{'s' if video_count != 1 else ''}", # video_count_text
227
+ gr.update(visible=webdataset_count > 0), # webdataset_files_row
228
+ f"Contains {webdataset_count} WebDataset (.tar) file{'s' if webdataset_count != 1 else ''}" # webdataset_count_text
229
+ )
230
+ except Exception as e:
231
+ logger.error(f"Error displaying dataset info: {str(e)}", exc_info=True)
232
+ return (
233
+ f"Error loading dataset information: {str(e)}", # dataset_info
234
+ None, # dataset_id
235
+ gr.update(visible=False), # files_section
236
+ gr.update(visible=False), # video_files_row
237
+ "", # video_count_text
238
+ gr.update(visible=False), # webdataset_files_row
239
+ "" # webdataset_count_text
240
+ )
241
+
242
+ def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str) -> Tuple[gr.update, str]:
243
+ """Handle download of a group of files (videos or WebDatasets)"""
244
+ try:
245
+ if not dataset_id:
246
+ return gr.update(visible=False), "No dataset selected"
247
+
248
+ logger.info(f"Starting download of {file_type} files from dataset: {dataset_id}")
249
+
250
+ # Show loading indicator
251
+ loading_msg = gr.update(
252
+ value=f"## Downloading {file_type} files from {dataset_id}\n\nThis may take some time...",
253
+ visible=True
254
+ )
255
+ status_msg = f"Downloading {file_type} files from {dataset_id}..."
256
+
257
+ # Use the async version in a non-blocking way
258
+ asyncio.create_task(self._download_file_group_bg(dataset_id, file_type, enable_splitting))
259
+
260
+ return loading_msg, status_msg
261
+
262
+ except Exception as e:
263
+ error_msg = f"Error initiating download: {str(e)}"
264
+ logger.error(error_msg, exc_info=True)
265
+ return gr.update(visible=False), error_msg
266
+
267
+ async def _download_file_group_bg(self, dataset_id: str, file_type: str, enable_splitting: bool):
268
+ """Background task for group file download"""
269
+ try:
270
+ # This will execute in the background
271
+ await self.app.importer.download_file_group(dataset_id, file_type, enable_splitting)
272
+ except Exception as e:
273
+ logger.error(f"Error in background file group download: {str(e)}", exc_info=True)
vms/tabs/{import_tab.py → import_tab/import_tab.py} RENAMED
@@ -1,33 +1,42 @@
1
  """
2
- Import tab for Video Model Studio UI
3
  """
4
 
5
  import gradio as gr
6
  import logging
7
  import asyncio
8
  from pathlib import Path
9
- from typing import Dict, Any, List, Optional
10
 
11
- from .base_tab import BaseTab
12
- from ..config import (
13
- VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
 
 
 
 
 
14
  )
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
  class ImportTab(BaseTab):
19
- """Import tab for uploading videos and images"""
20
 
21
  def __init__(self, app_state):
22
  super().__init__(app_state)
23
  self.id = "import_tab"
24
  self.title = "1️⃣ Import"
 
 
 
 
25
 
26
  def create(self, parent=None) -> gr.TabItem:
27
- """Create the Import tab UI components"""
28
  with gr.TabItem(self.title, id=self.id) as tab:
29
  with gr.Row():
30
- gr.Markdown("## Automatic splitting and captioning")
31
 
32
  with gr.Row():
33
  self.components["enable_automatic_video_split"] = gr.Checkbox(
@@ -42,38 +51,19 @@ class ImportTab(BaseTab):
42
  value=False,
43
  visible=True,
44
  )
 
 
 
 
 
 
 
45
 
46
- with gr.Row():
47
- with gr.Column(scale=3):
48
- with gr.Row():
49
- with gr.Column():
50
- gr.Markdown("## Import files")
51
- gr.Markdown("You can upload either:")
52
- gr.Markdown("- A single MP4 video file")
53
- gr.Markdown("- A ZIP archive containing multiple videos/images and optional caption files")
54
- gr.Markdown("- A WebDataset shard (.tar file)")
55
- gr.Markdown("- A ZIP archive containing WebDataset shards (.tar files)")
56
-
57
- with gr.Row():
58
- self.components["files"] = gr.Files(
59
- label="Upload Images, Videos, ZIP or WebDataset",
60
- file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip", ".tar"],
61
- type="filepath"
62
- )
63
-
64
- with gr.Column(scale=3):
65
- with gr.Row():
66
- with gr.Column():
67
- gr.Markdown("## Import a YouTube video")
68
- gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
69
-
70
- with gr.Row():
71
- self.components["youtube_url"] = gr.Textbox(
72
- label="Import YouTube Video",
73
- placeholder="https://www.youtube.com/watch?v=..."
74
- )
75
- with gr.Row():
76
- self.components["youtube_download_btn"] = gr.Button("Download YouTube Video", variant="secondary")
77
  with gr.Row():
78
  self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
79
 
@@ -81,47 +71,17 @@ class ImportTab(BaseTab):
81
 
82
  def connect_events(self) -> None:
83
  """Connect event handlers to UI components"""
84
- # File upload event
85
- self.components["files"].upload(
86
- fn=lambda x: self.app.importer.process_uploaded_files(x),
87
- inputs=[self.components["files"]],
88
- outputs=[self.components["import_status"]]
89
- ).success(
90
- fn=self.update_titles_after_import,
91
- inputs=[
92
- self.components["enable_automatic_video_split"],
93
- self.components["enable_automatic_content_captioning"],
94
- self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
95
- ],
96
- outputs=[
97
- self.app.tabs_component, # Main tabs component
98
- self.app.tabs["split_tab"].components["video_list"],
99
- self.app.tabs["split_tab"].components["detect_status"],
100
- self.app.tabs["split_tab"].components["split_title"],
101
- self.app.tabs["caption_tab"].components["caption_title"],
102
- self.app.tabs["train_tab"].components["train_title"]
103
- ]
104
- )
105
-
106
- # YouTube download event
107
- self.components["youtube_download_btn"].click(
108
- fn=self.app.importer.download_youtube_video,
109
- inputs=[self.components["youtube_url"]],
110
- outputs=[self.components["import_status"]]
111
- ).success(
112
- fn=self.on_import_success,
113
- inputs=[
114
- self.components["enable_automatic_video_split"],
115
- self.components["enable_automatic_content_captioning"],
116
- self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
117
- ],
118
- outputs=[
119
- self.app.tabs_component,
120
- self.app.tabs["split_tab"].components["video_list"],
121
- self.app.tabs["split_tab"].components["detect_status"]
122
- ]
123
- )
124
-
125
  async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
126
  """Handle successful import of files"""
127
  videos = self.app.tabs["split_tab"].list_unprocessed_videos()
 
1
  """
2
+ Parent import tab for Video Model Studio UI that contains sub-tabs
3
  """
4
 
5
  import gradio as gr
6
  import logging
7
  import asyncio
8
  from pathlib import Path
9
+ from typing import Dict, Any, List, Optional, Tuple
10
 
11
+ from ..base_tab import BaseTab
12
+ from .upload_tab import UploadTab
13
+ from .youtube_tab import YouTubeTab
14
+ from .hub_tab import HubTab
15
+
16
+ from vms.config import (
17
+ VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
18
+ STAGING_PATH
19
  )
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
  class ImportTab(BaseTab):
24
+ """Import tab for uploading videos and images, and browsing datasets"""
25
 
26
  def __init__(self, app_state):
27
  super().__init__(app_state)
28
  self.id = "import_tab"
29
  self.title = "1️⃣ Import"
30
+ # Initialize sub-tabs
31
+ self.upload_tab = UploadTab(app_state)
32
+ self.youtube_tab = YouTubeTab(app_state)
33
+ self.hub_tab = HubTab(app_state)
34
 
35
  def create(self, parent=None) -> gr.TabItem:
36
+ """Create the Import tab UI components with three sub-tabs"""
37
  with gr.TabItem(self.title, id=self.id) as tab:
38
  with gr.Row():
39
+ gr.Markdown("## Import settings")
40
 
41
  with gr.Row():
42
  self.components["enable_automatic_video_split"] = gr.Checkbox(
 
51
  value=False,
52
  visible=True,
53
  )
54
+
55
+ # Create tabs for different import methods
56
+ with gr.Tabs() as import_tabs:
57
+ # Create each sub-tab
58
+ self.upload_tab.create(import_tabs)
59
+ self.youtube_tab.create(import_tabs)
60
+ self.hub_tab.create(import_tabs)
61
 
62
+ # Store references to sub-tabs
63
+ self.components["upload_tab"] = self.upload_tab
64
+ self.components["youtube_tab"] = self.youtube_tab
65
+ self.components["hub_tab"] = self.hub_tab
66
+
 
 
67
  with gr.Row():
68
  self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
69
 
 
71
 
72
  def connect_events(self) -> None:
73
  """Connect event handlers to UI components"""
74
+ # Set shared components from parent tab to sub-tabs first
75
+ for subtab in [self.upload_tab, self.youtube_tab, self.hub_tab]:
76
+ subtab.components["import_status"] = self.components["import_status"]
77
+ subtab.components["enable_automatic_video_split"] = self.components["enable_automatic_video_split"]
78
+ subtab.components["enable_automatic_content_captioning"] = self.components["enable_automatic_content_captioning"]
79
+
80
+ # Then connect events for each sub-tab
81
+ self.upload_tab.connect_events()
82
+ self.youtube_tab.connect_events()
83
+ self.hub_tab.connect_events()
84
+
 
 
85
  async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
86
  """Handle successful import of files"""
87
  videos = self.app.tabs["split_tab"].list_unprocessed_videos()
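The parent tab above creates its sub-tabs inside a nested `gr.Tabs()` block, shares a few of its own components with them, and only then calls their `connect_events()`. A stripped-down sketch of that pattern is shown below; the minimal `BaseTab` here is an assumption (the real one lives in `vms/tabs/base_tab.py` and is not part of this diff).

```python
import gradio as gr


class BaseTab:
    """Assumed minimal base class; the real one is vms/tabs/base_tab.py."""

    def __init__(self, app_state):
        self.app = app_state
        self.components = {}


class ChildTab(BaseTab):
    def create(self, parent=None) -> gr.Tab:
        with gr.Tab("Child") as tab:
            self.components["button"] = gr.Button("Do something")
        return tab

    def connect_events(self) -> None:
        # "status" is injected by the parent before connect_events() runs
        self.components["button"].click(
            fn=lambda: "done",
            outputs=[self.components["status"]],
        )


class ParentTab(BaseTab):
    def __init__(self, app_state):
        super().__init__(app_state)
        self.child = ChildTab(app_state)

    def create(self, parent=None) -> gr.Tab:
        with gr.Tab("Parent") as tab:
            with gr.Tabs():
                self.child.create()
            self.components["status"] = gr.Textbox(label="Status", interactive=False)
        return tab

    def connect_events(self) -> None:
        # Share parent-owned components with the sub-tab, then wire its events
        self.child.components["status"] = self.components["status"]
        self.child.connect_events()


if __name__ == "__main__":
    with gr.Blocks() as demo:
        parent = ParentTab(app_state=None)
        parent.create()
        parent.connect_events()
    demo.launch()
```

Sharing `status` before `connect_events()` is what lets the upload, YouTube and Hub sub-tabs all write into the parent's single `import_status` textbox.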
vms/tabs/import_tab/upload_tab.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ Upload tab for Video Model Studio UI.
3
+ Handles manual file uploads for videos, images, and archives.
4
+ """
5
+
6
+ import gradio as gr
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Dict, Any, Optional
10
+
11
+ from ..base_tab import BaseTab
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class UploadTab(BaseTab):
16
+ """Upload tab for manual file uploads"""
17
+
18
+ def __init__(self, app_state):
19
+ super().__init__(app_state)
20
+ self.id = "upload_tab"
21
+ self.title = "Manual Upload"
22
+
23
+ def create(self, parent=None) -> gr.Tab:
24
+ """Create the Upload tab UI components"""
25
+ with gr.Tab(self.title, id=self.id) as tab:
26
+ with gr.Column():
27
+ with gr.Row():
28
+ gr.Markdown("## Manual upload of video files")
29
+
30
+ with gr.Row():
31
+ with gr.Column():
32
+ with gr.Row():
33
+ gr.Markdown("You can upload either:")
34
+ with gr.Row():
35
+ gr.Markdown("- A single MP4 video file")
36
+ with gr.Row():
37
+ gr.Markdown("- A ZIP archive containing multiple videos/images and optional caption files")
38
+ with gr.Row():
39
+ gr.Markdown("- A WebDataset shard (.tar file)")
40
+ with gr.Row():
41
+ gr.Markdown("- A ZIP archive containing WebDataset shards (.tar files)")
42
+ with gr.Column():
43
+ with gr.Row():
44
+ self.components["files"] = gr.Files(
45
+ label="Upload Images, Videos, ZIP or WebDataset",
46
+ file_types=[".jpg", ".jpeg", ".png", ".webp", ".avif", ".heic", ".mp4", ".zip", ".tar"],
47
+ type="filepath"
48
+ )
49
+
50
+ return tab
51
+
52
+ def connect_events(self) -> None:
53
+ """Connect event handlers to UI components"""
54
+ # File upload event
55
+ self.components["files"].upload(
56
+ fn=lambda x: self.app.importer.process_uploaded_files(x),
57
+ inputs=[self.components["files"]],
58
+ outputs=[self.components["import_status"]] # This comes from parent tab
59
+ ).success(
60
+ fn=self.app.tabs["import_tab"].update_titles_after_import,
61
+ inputs=[
62
+ self.components["enable_automatic_video_split"],
63
+ self.components["enable_automatic_content_captioning"],
64
+ self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
65
+ ],
66
+ outputs=[
67
+ self.app.tabs_component, # Main tabs component
68
+ self.app.tabs["split_tab"].components["video_list"],
69
+ self.app.tabs["split_tab"].components["detect_status"],
70
+ self.app.tabs["split_tab"].components["split_title"],
71
+ self.app.tabs["caption_tab"].components["caption_title"],
72
+ self.app.tabs["train_tab"].components["train_title"]
73
+ ]
74
+ )
vms/tabs/import_tab/youtube_tab.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ YouTube tab for Video Model Studio UI.
3
+ Handles downloading videos from YouTube URLs.
4
+ """
5
+
6
+ import gradio as gr
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Dict, Any, Optional
10
+
11
+ from ..base_tab import BaseTab
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class YouTubeTab(BaseTab):
16
+ """YouTube tab for downloading videos from YouTube"""
17
+
18
+ def __init__(self, app_state):
19
+ super().__init__(app_state)
20
+ self.id = "youtube_tab"
21
+ self.title = "Download from YouTube"
22
+
23
+ def create(self, parent=None) -> gr.Tab:
24
+ """Create the YouTube tab UI components"""
25
+ with gr.Tab(self.title, id=self.id) as tab:
26
+ with gr.Column():
27
+ with gr.Row():
28
+ gr.Markdown("## Import a YouTube video")
29
+
30
+ with gr.Row():
31
+ with gr.Column():
32
+ with gr.Row():
33
+ gr.Markdown("You can use a YouTube video as reference, by pasting its URL here:")
34
+ with gr.Row():
35
+ gr.Markdown("Please be aware of the [know limitations](https://stackoverflow.com/questions/78160027/how-to-solve-http-error-400-bad-request-in-pytube) and [issues](https://stackoverflow.com/questions/79226520/pytube-throws-http-error-403-forbidden-since-a-few-days)")
36
+
37
+ with gr.Column():
38
+ self.components["youtube_url"] = gr.Textbox(
39
+ label="Import YouTube Video",
40
+ placeholder="https://www.youtube.com/watch?v=..."
41
+ )
42
+
43
+ with gr.Row():
44
+ self.components["youtube_download_btn"] = gr.Button("Download YouTube Video", variant="primary")
45
+
46
+ return tab
47
+
48
+ def connect_events(self) -> None:
49
+ """Connect event handlers to UI components"""
50
+ # YouTube download event
51
+ self.components["youtube_download_btn"].click(
52
+ fn=self.app.importer.download_youtube_video,
53
+ inputs=[self.components["youtube_url"]],
54
+ outputs=[self.components["import_status"]] # This comes from parent tab
55
+ ).success(
56
+ fn=self.app.tabs["import_tab"].on_import_success,
57
+ inputs=[
58
+ self.components["enable_automatic_video_split"],
59
+ self.components["enable_automatic_content_captioning"],
60
+ self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
61
+ ],
62
+ outputs=[
63
+ self.app.tabs_component,
64
+ self.app.tabs["split_tab"].components["video_list"],
65
+ self.app.tabs["split_tab"].components["detect_status"]
66
+ ]
67
+ )
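The actual download is delegated to `self.app.importer.download_youtube_video`, whose implementation is not shown in this diff; the pytube links referenced in the tab suggest it is pytube-based. Purely as a hedged illustration of that kind of helper (the output directory and the choice of the highest-resolution progressive stream are assumptions):

```python
from pathlib import Path

from pytube import YouTube


def download_youtube_video(url: str, output_dir: str = "videos_to_split") -> str:
    """Download the highest-resolution progressive stream to output_dir."""
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stream = YouTube(url).streams.get_highest_resolution()
    # download() returns the path of the file it wrote
    return stream.download(output_path=output_dir)
```

The same pytube 400/403 issues mentioned in the tab's Markdown would apply to a helper like this one.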
vms/tabs/manage_tab.py CHANGED
@@ -23,7 +23,7 @@ class ManageTab(BaseTab):
23
  def __init__(self, app_state):
24
  super().__init__(app_state)
25
  self.id = "manage_tab"
26
- self.title = "5️⃣ Manage"
27
 
28
  def create(self, parent=None) -> gr.TabItem:
29
  """Create the Manage tab UI components"""
 
23
  def __init__(self, app_state):
24
  super().__init__(app_state)
25
  self.id = "manage_tab"
26
+ self.title = "6️⃣ Manage"
27
 
28
  def create(self, parent=None) -> gr.TabItem:
29
  """Create the Manage tab UI components"""
vms/tabs/monitor_tab.py ADDED
@@ -0,0 +1,407 @@
1
+ """
2
+ System monitoring tab for Video Model Studio UI.
3
+ Displays system metrics like CPU, memory usage, and temperatures.
4
+ """
5
+
6
+ import gradio as gr
7
+ import time
8
+ import logging
9
+ from pathlib import Path
10
+ import os
11
+ import psutil
12
+ from typing import Dict, Any, List, Optional, Tuple
13
+ from datetime import datetime, timedelta
14
+
15
+ from .base_tab import BaseTab
16
+ from ..config import STORAGE_PATH
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def get_folder_size(path):
21
+ """Calculate the total size of a folder in bytes"""
22
+ total_size = 0
23
+ for dirpath, dirnames, filenames in os.walk(path):
24
+ for filename in filenames:
25
+ file_path = os.path.join(dirpath, filename)
26
+ if not os.path.islink(file_path): # Skip symlinks
27
+ total_size += os.path.getsize(file_path)
28
+ return total_size
29
+
30
+ def human_readable_size(size_bytes):
31
+ """Convert a size in bytes to a human-readable string"""
32
+ if size_bytes == 0:
33
+ return "0 B"
34
+ size_names = ("B", "KB", "MB", "GB", "TB", "PB")
35
+ i = 0
36
+ while size_bytes >= 1024 and i < len(size_names) - 1:
37
+ size_bytes /= 1024
38
+ i += 1
39
+ return f"{size_bytes:.2f} {size_names[i]}"
40
+
41
+ class MonitorTab(BaseTab):
42
+ """Monitor tab for system resource monitoring"""
43
+
44
+ def __init__(self, app_state):
45
+ super().__init__(app_state)
46
+ self.id = "monitor_tab"
47
+ self.title = "5️⃣ Monitor"
48
+ self.refresh_interval = 2 # auto-refresh interval in seconds
49
+
50
+ def create(self, parent=None) -> gr.TabItem:
51
+ """Create the Monitor tab UI components"""
52
+ with gr.TabItem(self.title, id=self.id) as tab:
53
+ with gr.Row():
54
+ gr.Markdown("## System Monitoring")
55
+
56
+ # Current metrics
57
+ with gr.Row():
58
+ with gr.Column(scale=1):
59
+ self.components["current_metrics"] = gr.Markdown("Loading current metrics...")
60
+
61
+ # CPU and Memory charts in tabs
62
+ with gr.Tabs() as metrics_tabs:
63
+ with gr.Tab(label="CPU Usage") as cpu_tab:
64
+ self.components["cpu_plot"] = gr.Plot()
65
+
66
+ with gr.Tab(label="Memory Usage") as memory_tab:
67
+ self.components["memory_plot"] = gr.Plot()
68
+
69
+ with gr.Tab(label="Per-Core CPU") as per_core_tab:
70
+ self.components["per_core_plot"] = gr.Plot()
71
+
72
+ # System information summary in columns
73
+ with gr.Row():
74
+ with gr.Column(scale=1):
75
+ gr.Markdown("### System Information")
76
+ self.components["system_info"] = gr.Markdown("Loading system information...")
77
+
78
+ with gr.Column(scale=1):
79
+ gr.Markdown("### CPU Information")
80
+ self.components["cpu_info"] = gr.Markdown("Loading CPU information...")
81
+
82
+ with gr.Row():
83
+ with gr.Column(scale=1):
84
+ gr.Markdown("### Memory Information")
85
+ self.components["memory_info"] = gr.Markdown("Loading memory information...")
86
+
87
+ with gr.Column(scale=1):
88
+ gr.Markdown("### Storage Information")
89
+ self.components["storage_info"] = gr.Markdown("Loading storage information...")
90
+
91
+ # Toggle for enabling/disabling auto-refresh
92
+ with gr.Row():
93
+ self.components["auto_refresh"] = gr.Checkbox(
94
+ label=f"Auto refresh (every {self.refresh_interval} seconds)",
95
+ value=True,
96
+ info="Automatically refresh system metrics"
97
+ )
98
+ self.components["refresh_btn"] = gr.Button("Refresh Now")
99
+
100
+ # Timer for auto-refresh
101
+ self.components["refresh_timer"] = gr.Timer(
102
+ value=self.refresh_interval
103
+ )
104
+
105
+ return tab
106
+
107
+ def connect_events(self) -> None:
108
+ """Connect event handlers to UI components"""
109
+ # Manual refresh button
110
+ self.components["refresh_btn"].click(
111
+ fn=self.refresh_all,
112
+ outputs=[
113
+ self.components["system_info"],
114
+ self.components["cpu_info"],
115
+ self.components["memory_info"],
116
+ self.components["storage_info"],
117
+ self.components["current_metrics"],
118
+ self.components["cpu_plot"],
119
+ self.components["memory_plot"],
120
+ self.components["per_core_plot"]
121
+ ]
122
+ )
123
+
124
+ # Auto-refresh timer
125
+ self.components["refresh_timer"].tick(
126
+ fn=self.conditional_refresh,
127
+ inputs=[self.components["auto_refresh"]],
128
+ outputs=[
129
+ self.components["system_info"],
130
+ self.components["cpu_info"],
131
+ self.components["memory_info"],
132
+ self.components["storage_info"],
133
+ self.components["current_metrics"],
134
+ self.components["cpu_plot"],
135
+ self.components["memory_plot"],
136
+ self.components["per_core_plot"]
137
+ ]
138
+ )
139
+
140
+ def on_enter(self):
141
+ """Called when the tab is selected"""
142
+ # Start monitoring service if not already running
143
+ if not self.app.monitor.is_running:
144
+ self.app.monitor.start_monitoring()
145
+
146
+ # Trigger initial refresh
147
+ return self.refresh_all()
148
+
149
+ def conditional_refresh(self, auto_refresh: bool) -> Tuple:
150
+ """Only refresh if auto-refresh is enabled
151
+
152
+ Args:
153
+ auto_refresh: Whether auto-refresh is enabled
154
+
155
+ Returns:
156
+ Updated components or unchanged components
157
+ """
158
+ if auto_refresh:
159
+ return self.refresh_all()
160
+
161
+ # Return current values unchanged if auto-refresh is disabled
162
+ return (
163
+ self.components["system_info"].value,
164
+ self.components["cpu_info"].value,
165
+ self.components["memory_info"].value,
166
+ self.components["storage_info"].value,
167
+ self.components["current_metrics"].value,
168
+ self.components["cpu_plot"].value,
169
+ self.components["memory_plot"].value,
170
+ self.components["per_core_plot"].value
171
+ )
172
+
173
+ def refresh_all(self) -> Tuple:
174
+ """Refresh all monitoring components
175
+
176
+ Returns:
177
+ Updated values for all components
178
+ """
179
+ try:
180
+ # Get system info
181
+ system_info = self.app.monitor.get_system_info()
182
+
183
+ # Split system info into separate components
184
+ system_info_html = self.format_system_info(system_info)
185
+ cpu_info_html = self.format_cpu_info(system_info)
186
+ memory_info_html = self.format_memory_info(system_info)
187
+ storage_info_html = self.format_storage_info()
188
+
189
+ # Get current metrics
190
+ current_metrics = self.app.monitor.get_current_metrics()
191
+ metrics_html = self.format_current_metrics(current_metrics)
192
+
193
+ # Generate plots
194
+ cpu_plot = self.app.monitor.generate_cpu_plot()
195
+ memory_plot = self.app.monitor.generate_memory_plot()
196
+ per_core_plot = self.app.monitor.generate_per_core_plot()
197
+
198
+ return (
199
+ system_info_html,
200
+ cpu_info_html,
201
+ memory_info_html,
202
+ storage_info_html,
203
+ metrics_html,
204
+ cpu_plot,
205
+ memory_plot,
206
+ per_core_plot
207
+ )
208
+
209
+ except Exception as e:
210
+ logger.error(f"Error refreshing monitoring data: {str(e)}", exc_info=True)
211
+ error_msg = f"Error retrieving data: {str(e)}"
212
+ return (
213
+ error_msg,
214
+ error_msg,
215
+ error_msg,
216
+ error_msg,
217
+ error_msg,
218
+ None, None, None
219
+ )
220
+
221
+ def format_system_info(self, system_info: Dict[str, Any]) -> str:
222
+ """Format system information as HTML
223
+
224
+ Args:
225
+ system_info: System information dictionary
226
+
227
+ Returns:
228
+ Formatted HTML string
229
+ """
230
+ sys = system_info['system']
231
+ uptime_str = self.format_uptime(sys['uptime'])
232
+
233
+ html = f"""
234
+ **System:** {sys['system']} ({sys['platform']})
235
+ **Hostname:** {sys['hostname']}
236
+ **Uptime:** {uptime_str}
237
+ **Python Version:** {sys['python_version']}
238
+ """
239
+ return html
240
+
241
+ def format_cpu_info(self, system_info: Dict[str, Any]) -> str:
242
+ """Format CPU information as HTML
243
+
244
+ Args:
245
+ system_info: System information dictionary
246
+
247
+ Returns:
248
+ Formatted HTML string
249
+ """
250
+ cpu = system_info['cpu']
251
+ sys = system_info['system']
252
+
253
+ # Format CPU frequency
254
+ cpu_freq = "N/A"
255
+ if cpu['current_frequency']:
256
+ cpu_freq = f"{cpu['current_frequency'] / 1000:.2f} GHz"
257
+
258
+ html = f"""
259
+ **Processor:** {sys['processor'] or cpu['architecture']}
260
+ **Physical Cores:** {cpu['cores_physical']}
261
+ **Logical Cores:** {cpu['cores_logical']}
262
+ **Current Frequency:** {cpu_freq}
263
+ """
264
+ return html
265
+
266
+ def format_memory_info(self, system_info: Dict[str, Any]) -> str:
267
+ """Format memory information as HTML
268
+
269
+ Args:
270
+ system_info: System information dictionary
271
+
272
+ Returns:
273
+ Formatted HTML string
274
+ """
275
+ memory = system_info['memory']
276
+
277
+ html = f"""
278
+ **Total Memory:** {memory['total']:.2f} GB
279
+ **Available Memory:** {memory['available']:.2f} GB
280
+ **Used Memory:** {memory['used']:.2f} GB
281
+ **Usage:** {memory['percent']}%
282
+ """
283
+ return html
284
+
285
+ def format_storage_info(self) -> str:
286
+ """Format storage information as HTML, focused on STORAGE_PATH
287
+
288
+ Returns:
289
+ Formatted HTML string
290
+ """
291
+ try:
292
+ # Get total size of STORAGE_PATH
293
+ total_size = get_folder_size(STORAGE_PATH)
294
+ total_size_readable = human_readable_size(total_size)
295
+
296
+ html = f"**Total Storage Used:** {total_size_readable}\n\n"
297
+
298
+ # Get size of each subfolder
299
+ html += "**Subfolder Sizes:**\n\n"
300
+
301
+ for subfolder in sorted(STORAGE_PATH.iterdir()):
302
+ if subfolder.is_dir():
303
+ folder_size = get_folder_size(subfolder)
304
+ folder_size_readable = human_readable_size(folder_size)
305
+ percentage = (folder_size / total_size * 100) if total_size > 0 else 0
306
+
307
+ folder_name = subfolder.name
308
+ html += f"* **{folder_name}**: {folder_size_readable} ({percentage:.1f}%)\n"
309
+
310
+ return html
311
+
312
+ except Exception as e:
313
+ logger.error(f"Error getting folder sizes: {str(e)}", exc_info=True)
314
+ return f"Error getting folder sizes: {str(e)}"
315
+
316
+ def format_current_metrics(self, metrics: Dict[str, Any]) -> str:
317
+ """Format current metrics as HTML
318
+
319
+ Args:
320
+ metrics: Current metrics dictionary
321
+
322
+ Returns:
323
+ Formatted HTML string
324
+ """
325
+ timestamp = metrics['timestamp'].strftime('%Y-%m-%d %H:%M:%S')
326
+
327
+ # Style for CPU usage
328
+ cpu_style = "color: green;"
329
+ if metrics['cpu_percent'] > 90:
330
+ cpu_style = "color: red; font-weight: bold;"
331
+ elif metrics['cpu_percent'] > 70:
332
+ cpu_style = "color: orange;"
333
+
334
+ # Style for memory usage
335
+ mem_style = "color: green;"
336
+ if metrics['memory_percent'] > 90:
337
+ mem_style = "color: red; font-weight: bold;"
338
+ elif metrics['memory_percent'] > 70:
339
+ mem_style = "color: orange;"
340
+
341
+ # Temperature info
342
+ temp_html = ""
343
+ if metrics['cpu_temp'] is not None:
344
+ temp_style = "color: green;"
345
+ if metrics['cpu_temp'] > 80:
346
+ temp_style = "color: red; font-weight: bold;"
347
+ elif metrics['cpu_temp'] > 70:
348
+ temp_style = "color: orange;"
349
+
350
+ temp_html = f"""
351
+ **CPU Temperature:** <span style="{temp_style}">{metrics['cpu_temp']:.1f}°C</span>
352
+ """
353
+
354
+ html = f"""
355
+ **CPU Usage:** <span style="{cpu_style}">{metrics['cpu_percent']:.1f}%</span>
356
+ **Memory Usage:** <span style="{mem_style}">{metrics['memory_percent']:.1f}% ({metrics['memory_used']:.2f}/{metrics['memory_available']:.2f} GB)</span>
357
+ {temp_html}
358
+ """
359
+
360
+ # Add per-CPU core info
361
+ html += "\n"
362
+
363
+ per_cpu = metrics['per_cpu_percent']
364
+ cols = 4 # 4 cores per row
365
+
366
+ # Create a grid layout for cores
367
+ for i in range(0, len(per_cpu), cols):
368
+ row_cores = per_cpu[i:i+cols]
369
+ row_html = ""
370
+
371
+ for j, usage in enumerate(row_cores):
372
+ core_id = i + j
373
+ core_style = "color: green;"
374
+ if usage > 90:
375
+ core_style = "color: red; font-weight: bold;"
376
+ elif usage > 70:
377
+ core_style = "color: orange;"
378
+
379
+ row_html += f"**Core {core_id}:** <span style='{core_style}'>{usage:.1f}%</span>&nbsp;&nbsp;&nbsp;"
380
+
381
+ html += row_html + "\n"
382
+
383
+ return html
384
+
385
+ def format_uptime(self, seconds: float) -> str:
386
+ """Format uptime in seconds to a human-readable string
387
+
388
+ Args:
389
+ seconds: Uptime in seconds
390
+
391
+ Returns:
392
+ Formatted uptime string
393
+ """
394
+ days = int(seconds // 86400)
395
+ seconds %= 86400
396
+ hours = int(seconds // 3600)
397
+ seconds %= 3600
398
+ minutes = int(seconds // 60)
399
+
400
+ parts = []
401
+ if days > 0:
402
+ parts.append(f"{days} day{'s' if days != 1 else ''}")
403
+ if hours > 0 or days > 0:
404
+ parts.append(f"{hours} hour{'s' if hours != 1 else ''}")
405
+ parts.append(f"{minutes} minute{'s' if minutes != 1 else ''}")
406
+
407
+ return ", ".join(parts)
vms/ui/video_trainer_ui.py CHANGED
@@ -5,7 +5,7 @@ import logging
5
  import asyncio
6
  from typing import Any, Optional, Dict, List, Union, Tuple
7
 
8
- from ..services import TrainingService, CaptioningService, SplittingService, ImportService
9
  from ..config import (
10
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
11
  TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
@@ -28,7 +28,7 @@ from ..utils import (
28
  format_media_title,
29
  TrainingLogParser
30
  )
31
- from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
32
 
33
  logger = logging.getLogger(__name__)
34
  logger.setLevel(logging.INFO)
@@ -44,7 +44,11 @@ class VideoTrainerUI:
44
  self.splitter = SplittingService()
45
  self.importer = ImportService()
46
  self.captioner = CaptioningService()
47
-
 
 
 
 
48
  # Recovery status from any interrupted training
49
  recovery_result = self.trainer.recover_interrupted_training()
50
  # Add null check for recovery_result
@@ -81,6 +85,7 @@ class VideoTrainerUI:
81
  self.tabs["split_tab"] = SplitTab(self)
82
  self.tabs["caption_tab"] = CaptionTab(self)
83
  self.tabs["train_tab"] = TrainTab(self)
 
84
  self.tabs["manage_tab"] = ManageTab(self)
85
 
86
  # Create tab UI components
 
5
  import asyncio
6
  from typing import Any, Optional, Dict, List, Union, Tuple
7
 
8
+ from ..services import TrainingService, CaptioningService, SplittingService, ImportService, MonitoringService
9
  from ..config import (
10
  STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
11
  TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
 
28
  format_media_title,
29
  TrainingLogParser
30
  )
31
+ from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, MonitorTab, ManageTab
32
 
33
  logger = logging.getLogger(__name__)
34
  logger.setLevel(logging.INFO)
 
44
  self.splitter = SplittingService()
45
  self.importer = ImportService()
46
  self.captioner = CaptioningService()
47
+ self.monitor = MonitoringService()
48
+
49
+ # Start the monitoring service on app creation
50
+ self.monitor.start_monitoring()
51
+
52
  # Recovery status from any interrupted training
53
  recovery_result = self.trainer.recover_interrupted_training()
54
  # Add null check for recovery_result
 
85
  self.tabs["split_tab"] = SplitTab(self)
86
  self.tabs["caption_tab"] = CaptionTab(self)
87
  self.tabs["train_tab"] = TrainTab(self)
88
+ self.tabs["monitor_tab"] = MonitorTab(self)
89
  self.tabs["manage_tab"] = ManageTab(self)
90
 
91
  # Create tab UI components