jbilcke-hf (HF staff) committed
Commit aa1e877 · 1 Parent(s): bc6e5cf

working on hf dataset downloader

app.py CHANGED
@@ -65,7 +65,7 @@ def main():
     ]
 
     # Launch the Gradio app
-    app.queue(default_concurrency_limit=1).launch(
+    app.queue(default_concurrency_limit=2).launch(
         server_name="0.0.0.0",
         allowed_paths=allowed_paths
     )
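
Note on the change above: raising default_concurrency_limit from 1 to 2 lets two queued Gradio events execute at once instead of strictly serializing them, presumably so long-running Hub downloads no longer block the rest of the UI queue. A minimal self-contained sketch of the same setting (the echo handler and components are illustrative, not from this repo):

import gradio as gr

def echo(text):
    # Stand-in for a real event handler.
    return text

with gr.Blocks() as demo:
    box = gr.Textbox()
    out = gr.Textbox()
    box.submit(echo, box, out)

# With default_concurrency_limit=2, two queued events may run
# concurrently; with 1, every event waits for the previous one.
demo.queue(default_concurrency_limit=2).launch(server_name="0.0.0.0")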
vms/services/importer/hub_dataset.py CHANGED
@@ -10,7 +10,7 @@ import asyncio
 import logging
 import gradio as gr
 from pathlib import Path
-from typing import List, Dict, Optional, Tuple, Any, Union
+from typing import List, Dict, Optional, Tuple, Any, Union, Callable
 
 from huggingface_hub import (
     HfApi,
@@ -43,6 +43,7 @@ class HubDatasetBrowser:
 
         Returns:
             List of datasets matching the query [id, title, downloads]
+            Note: We still return all columns internally, but the UI will only display the first column
         """
         try:
             # Start with some filters to find video-related datasets
@@ -126,15 +127,10 @@
 
             # Add basic stats (with safer access)
             downloads = getattr(dataset_info, 'downloads', None)
-            info_text += f"**Downloads:** {downloads if downloads is not None else 'N/A'}\n"
+            info_text += f"## Downloads: {downloads if downloads is not None else 'N/A'}\n"
 
             last_modified = getattr(dataset_info, 'last_modified', None)
-            info_text += f"**Last modified:** {last_modified if last_modified is not None else 'N/A'}\n"
-
-            # Show tags if available (with safer access)
-            tags = getattr(dataset_info, "tags", None) or []
-            if tags:
-                info_text += f"**Tags:** {', '.join(tags[:10])}\n\n"
+            info_text += f"## Last modified: {last_modified if last_modified is not None else 'N/A'}\n"
 
             # Group files by type
             file_groups = {
@@ -168,13 +164,20 @@
             logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
             return f"Error loading dataset information: {str(e)}", {}, {}
 
-    async def download_file_group(self, dataset_id: str, file_type: str, enable_splitting: bool = True) -> str:
+    async def download_file_group(
+        self,
+        dataset_id: str,
+        file_type: str,
+        enable_splitting: bool = True,
+        progress_callback: Optional[Callable] = None
+    ) -> str:
         """Download all files of a specific type from the dataset
 
         Args:
             dataset_id: The dataset ID
             file_type: Either "video" or "webdataset"
             enable_splitting: Whether to enable automatic video splitting
+            progress_callback: Optional callback for progress updates
 
         Returns:
             Status message
@@ -190,6 +193,11 @@
                 return f"No {file_type} files found in the dataset"
 
             logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
+            gr.Info(f"Starting download of {len(files)} {file_type} files from {dataset_id}")
+
+            # Initialize progress if callback provided
+            if progress_callback:
+                progress_callback(0, desc=f"Starting download of {len(files)} {file_type} files", total=len(files))
 
             # Track counts for status message
             video_count = 0
@@ -200,8 +208,16 @@
                 temp_path = Path(temp_dir)
 
                 # Process all files of the requested type
-                for filename in files:
+                for i, filename in enumerate(files):
                     try:
+                        # Update progress
+                        if progress_callback:
+                            progress_callback(
+                                i,
+                                desc=f"Downloading file {i+1}/{len(files)}: {Path(filename).name}",
+                                total=len(files)
+                            )
+
                         # Download the file
                         file_path = hf_hub_download(
                             repo_id=dataset_id,
@@ -212,6 +228,7 @@
 
                         file_path = Path(file_path)
                         logger.info(f"Downloaded file to {file_path}")
+                        #gr.Info(f"Downloaded {file_path.name} ({i+1}/{len(files)})")
 
                         # Process based on file type
                         if file_type == "video":
@@ -274,9 +291,13 @@
                     except Exception as e:
                         logger.warning(f"Error processing file {filename}: {e}")
 
+            # Update progress to complete
+            if progress_callback:
+                progress_callback(len(files), desc="Download complete", total=len(files))
+
             # Generate status message
             if file_type == "video":
-                return f"Successfully imported {video_count} videos from dataset {dataset_id}"
+                status_msg = f"Successfully imported {video_count} videos from dataset {dataset_id}"
             elif file_type == "webdataset":
                 parts = []
                 if video_count > 0:
@@ -285,23 +306,37 @@
                     parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
 
                 if parts:
-                    return f"Successfully imported {' and '.join(parts)} from WebDataset archives"
+                    status_msg = f"Successfully imported {' and '.join(parts)} from WebDataset archives"
                 else:
-                    return f"No media was found in the WebDataset archives"
+                    status_msg = f"No media was found in the WebDataset archives"
+            else:
+                status_msg = f"Unknown file type: {file_type}"
 
-            return f"Unknown file type: {file_type}"
+            # Final notification
+            logger.info(f"✅ Download complete! {status_msg}")
+            # This info message will appear as a toast notification
+            gr.Info(f"✅ Download complete! {status_msg}")
+
+            return status_msg
 
         except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
+           gr.Error(error_msg)
            return error_msg
 
-    async def download_dataset(self, dataset_id: str, enable_splitting: bool = True) -> Tuple[str, str]:
+    async def download_dataset(
+        self,
+        dataset_id: str,
+        enable_splitting: bool = True,
+        progress_callback: Optional[Callable] = None
+    ) -> Tuple[str, str]:
         """Download a dataset and process its video/image content
 
         Args:
             dataset_id: The dataset ID to download
             enable_splitting: Whether to enable automatic video splitting
+            progress_callback: Optional callback for progress tracking
 
         Returns:
             Tuple of (loading_msg, status_msg)
@@ -327,9 +362,15 @@
            video_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith((".mp4", ".webm"))]
            tar_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith(".tar")]
 
+           # Initialize progress tracking
+           total_files = len(video_files) + len(tar_files)
+           if progress_callback:
+               progress_callback(0, desc=f"Starting download of dataset: {dataset_id}", total=total_files)
+
            # Create a temporary directory for downloads
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
+               files_processed = 0
 
                # If we have video files, download them individually
                if video_files:
@@ -337,6 +378,14 @@
                    logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")
 
                    for i, video_file in enumerate(video_files):
+                       # Update progress
+                       if progress_callback:
+                           progress_callback(
+                               files_processed,
+                               desc=f"Downloading video {i+1}/{len(video_files)}: {Path(video_file).name}",
+                               total=total_files
+                           )
+
                        # Download the video file
                        try:
                            file_path = hf_hub_download(
@@ -369,6 +418,7 @@
 
                            status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
                            logger.info(status_msg)
+                           files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {video_file}: {e}")
 
@@ -378,6 +428,14 @@
                    logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")
 
                    for i, tar_file in enumerate(tar_files):
+                       # Update progress
+                       if progress_callback:
+                           progress_callback(
+                               files_processed,
+                               desc=f"Downloading WebDataset {i+1}/{len(tar_files)}: {Path(tar_file).name}",
+                               total=total_files
+                           )
+
                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
@@ -387,6 +445,7 @@
                            )
                            status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
                            logger.info(status_msg)
+                           files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {tar_file}: {e}")
 
@@ -395,6 +454,9 @@
                    loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
                    logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")
 
+                   if progress_callback:
+                       progress_callback(0, desc=f"Downloading entire repository for {dataset_id}", total=1)
+
                    try:
                        snapshot_download(
                            repo_id=dataset_id,
@@ -403,6 +465,9 @@
                        )
                        status_msg = f"Downloaded entire repository for {dataset_id}"
                        logger.info(status_msg)
+
+                       if progress_callback:
+                           progress_callback(1, desc="Repository download complete", total=1)
                    except Exception as e:
                        logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
                        return loading_msg, f"Error downloading dataset: {str(e)}"
@@ -411,6 +476,9 @@
                loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
                logger.info(f"Processing downloaded files from {dataset_id}")
 
+               if progress_callback:
+                   progress_callback(0, desc="Processing downloaded files", total=100)
+
                # Count imported files
                video_count = 0
                image_count = 0
@@ -420,11 +488,28 @@
                async def process_files():
                    nonlocal video_count, image_count, tar_count
 
+                   # Get total number of files to process
+                   file_count = 0
+                   for root, _, files in os.walk(temp_path):
+                       file_count += len(files)
+
+                   processed = 0
+
                    # Process all files in the temp directory
                    for root, _, files in os.walk(temp_path):
                        for file in files:
                            file_path = Path(root) / file
 
+                           # Update progress (every 5 files to avoid too many updates)
+                           if progress_callback and processed % 5 == 0:
+                               if file_count > 0:
+                                   progress_percent = int((processed / file_count) * 100)
+                                   progress_callback(
+                                       progress_percent,
+                                       desc=f"Processing files: {processed}/{file_count}",
+                                       total=100
+                                   )
+
                            # Process videos
                            if file.lower().endswith((".mp4", ".webm")):
                                # Choose target path based on auto-splitting setting
@@ -490,10 +575,16 @@
                                    logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
                                except Exception as e:
                                    logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
+
+                           processed += 1
 
                # Run the processing asynchronously
                await process_files()
 
+               # Update progress to complete
+               if progress_callback:
+                   progress_callback(100, desc="Processing complete", total=100)
+
                # Generate final status message
                parts = []
                if video_count > 0:
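
Both methods above invoke the new progress_callback as callback(value, desc=..., total=...), where value is a completed-item count when total is the file count, or a 0-100 percentage in the snapshot-processing path. A minimal compatible callback, with a hypothetical browser instance (not defined in this diff):

import asyncio

def print_progress(value, desc=None, total=None):
    # Matches the (value, desc=..., total=...) call signature used above.
    if total:
        print(f"[{value}/{total}] {desc or ''}")
    else:
        print(f"[{value}] {desc or ''}")

# Hypothetical usage, assuming `browser` is a HubDatasetBrowser:
# status = asyncio.run(browser.download_file_group(
#     "some-user/some-dataset", "video",
#     enable_splitting=True, progress_callback=print_progress))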
vms/services/importer/import_service.py CHANGED
@@ -4,7 +4,7 @@ Delegates to specialized handler classes for different import types.
 """
 
 import logging
-from typing import List, Dict, Optional, Tuple, Any, Union
+from typing import List, Dict, Optional, Tuple, Any, Union, Callable
 from pathlib import Path
 import gradio as gr
 
@@ -76,27 +76,40 @@ class ImportService:
         """
         return self.hub_browser.get_dataset_info(dataset_id)
 
-    async def download_dataset(self, dataset_id: str, enable_splitting: bool = True) -> Tuple[str, str]:
+    async def download_dataset(
+        self,
+        dataset_id: str,
+        enable_splitting: bool = True,
+        progress_callback: Optional[Callable] = None
+    ) -> Tuple[str, str]:
         """Download a dataset and process its video/image content
 
         Args:
             dataset_id: The dataset ID to download
             enable_splitting: Whether to enable automatic video splitting
+            progress_callback: Optional callback for progress tracking
 
         Returns:
             Tuple of (loading_msg, status_msg)
         """
-        return await self.hub_browser.download_dataset(dataset_id, enable_splitting)
+        return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback)
 
-    async def download_file_group(self, dataset_id: str, file_type: str, enable_splitting: bool = True) -> str:
+    async def download_file_group(
+        self,
+        dataset_id: str,
+        file_type: str,
+        enable_splitting: bool = True,
+        progress_callback: Optional[Callable] = None
+    ) -> str:
         """Download a group of files (videos or WebDatasets)
 
         Args:
             dataset_id: The dataset ID
             file_type: Type of file ("video" or "webdataset")
            enable_splitting: Whether to enable automatic video splitting
+           progress_callback: Optional callback for progress tracking
 
         Returns:
             Status message
         """
-        return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting)
+        return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback)
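
Since progress_callback defaults to None at every layer and the browser guards each use with "if progress_callback:", existing call sites that omit the argument keep working unchanged. A sketch of the forwarding contract, assuming a constructed service instance (illustrative, not from this diff):

import asyncio
from typing import Callable, Optional

async def import_with_progress(service, cb: Optional[Callable] = None):
    # ImportService forwards the callback untouched to HubDatasetBrowser,
    # so the same (value, desc=..., total=...) signature applies here.
    loading_msg, status_msg = await service.download_dataset(
        "some-user/some-dataset", enable_splitting=True, progress_callback=cb)
    return loading_msg, status_msg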
vms/tabs/import_tab/hub_tab.py CHANGED
@@ -6,6 +6,7 @@ Handles browsing, searching, and importing datasets from the Hugging Face Hub.
 import gradio as gr
 import logging
 import asyncio
+import threading
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple
 
@@ -20,6 +21,7 @@ class HubTab(BaseTab):
         super().__init__(app_state)
         self.id = "hub_tab"
         self.title = "Import from Hugging Face"
+        self.is_downloading = False
 
     def create(self, parent=None) -> gr.Tab:
         """Create the Hub tab UI components"""
@@ -33,8 +35,8 @@
 
         with gr.Row():
             self.components["dataset_search"] = gr.Textbox(
-                label="Search Hugging Face Datasets",
-                placeholder="Search for video datasets..."
+                label="Search Hugging Face Datasets (eg. cakeify, disney, rickroll..)",
+                placeholder="Search for video datasets (eg. cakeify, disney, rickroll..)"
             )
 
         with gr.Row():
@@ -46,7 +48,7 @@
 
         with gr.Column(scale=3):
             self.components["dataset_results"] = gr.Dataframe(
-                headers=["id", "title", "downloads"],
+                headers=["Dataset ID"],  # Simplified to show only dataset ID
                 interactive=False,
                 wrap=True,
                 row_count=10,
@@ -58,6 +60,7 @@
         self.components["dataset_info"] = gr.Markdown("Select a dataset to see details")
         self.components["dataset_id"] = gr.State(value=None)
         self.components["file_type"] = gr.State(value=None)
+        self.components["download_in_progress"] = gr.State(value=False)
 
         # Files section that appears when a dataset is selected
         with gr.Column(visible=False) as files_section:
@@ -66,27 +69,23 @@
             gr.Markdown("## Files:")
 
             # Video files row (appears if videos are present)
-            with gr.Row(visible=False) as video_files_row:
+            with gr.Row() as video_files_row:
                 self.components["video_files_row"] = video_files_row
 
-                with gr.Column(scale=4):
-                    self.components["video_count_text"] = gr.Markdown("Contains 0 video files")
+                self.components["video_count_text"] = gr.Markdown("Contains 0 video files")
 
-                with gr.Column(scale=1):
-                    self.components["download_videos_btn"] = gr.Button("Download", variant="primary")
+                self.components["download_videos_btn"] = gr.Button("Download", variant="primary")
 
             # WebDataset files row (appears if tar files are present)
-            with gr.Row(visible=False) as webdataset_files_row:
+            with gr.Row() as webdataset_files_row:
                 self.components["webdataset_files_row"] = webdataset_files_row
 
-                with gr.Column(scale=4):
-                    self.components["webdataset_count_text"] = gr.Markdown("Contains 0 WebDataset (.tar) files")
+                self.components["webdataset_count_text"] = gr.Markdown("Contains 0 WebDataset (.tar) files")
 
-                with gr.Column(scale=1):
-                    self.components["download_webdataset_btn"] = gr.Button("Download", variant="primary")
+                self.components["download_webdataset_btn"] = gr.Button("Download", variant="primary")
 
-            # Status and loading indicators
-            self.components["dataset_loading"] = gr.Markdown(visible=False)
+            # Status indicator
+            self.components["status_output"] = gr.Markdown("")
 
         return tab
 
@@ -102,7 +101,7 @@
             ]
         )
 
-        # Dataset selection event - FIX HERE
+        # Dataset selection event
         self.components["dataset_results"].select(
             fn=self.display_dataset_info,
             outputs=[
@@ -112,7 +111,8 @@
                 self.components["video_files_row"],
                 self.components["video_count_text"],
                 self.components["webdataset_files_row"],
-                self.components["webdataset_count_text"]
+                self.components["webdataset_count_text"],
+                self.components["status_output"]  # Reset status output
             ]
         )
 
@@ -128,20 +128,11 @@
                 self.components["file_type"]
             ],
             outputs=[
-                self.components["dataset_loading"],
-                self.components["import_status"]
-            ]
-        ).success(
-            fn=self.app.tabs["import_tab"].on_import_success,
-            inputs=[
-                self.components["enable_automatic_video_split"],
-                self.components["enable_automatic_content_captioning"],
-                self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
-            ],
-            outputs=[
-                self.app.tabs_component,
-                self.app.tabs["split_tab"].components["video_list"],
-                self.app.tabs["split_tab"].components["detect_status"]
+                self.components["status_output"],
+                self.components["import_status"],
+                self.components["download_videos_btn"],
+                self.components["download_webdataset_btn"],
+                self.components["download_in_progress"]
             ]
         )
 
@@ -157,20 +148,11 @@
                 self.components["file_type"]
             ],
             outputs=[
-                self.components["dataset_loading"],
-                self.components["import_status"]
-            ]
-        ).success(
-            fn=self.app.tabs["import_tab"].on_import_success,
-            inputs=[
-                self.components["enable_automatic_video_split"],
-                self.components["enable_automatic_content_captioning"],
-                self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
-            ],
-            outputs=[
-                self.app.tabs_component,
-                self.app.tabs["split_tab"].components["video_list"],
-                self.app.tabs["split_tab"].components["detect_status"]
+                self.components["status_output"],
+                self.components["import_status"],
+                self.components["download_videos_btn"],
+                self.components["download_webdataset_btn"],
+                self.components["download_in_progress"]
             ]
         )
 
@@ -186,12 +168,16 @@
         """Search datasets on the Hub matching the query"""
         try:
             logger.info(f"Searching for datasets with query: '{query}'")
-            results = self.app.importer.search_datasets(query)
+            results_full = self.app.importer.search_datasets(query)
+
+            # Extract just the first column (dataset IDs) for display
+            results = [[row[0]] for row in results_full]
+
             return results, gr.update(visible=True)
         except Exception as e:
             logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
-            return [[f"Error: {str(e)}", "", ""]], gr.update(visible=True)
+            return [[f"Error: {str(e)}"]], gr.update(visible=True)
 
     def display_dataset_info(self, evt: gr.SelectData):
         """Display detailed information about the selected dataset"""
         try:
@@ -204,9 +190,11 @@
                     gr.update(visible=False),  # video_files_row
                     "",  # video_count_text
                     gr.update(visible=False),  # webdataset_files_row
-                    ""  # webdataset_count_text
+                    "",  # webdataset_count_text
+                    ""  # status_output
                 )
 
+            # Extract dataset_id from the simplified format
            dataset_id = evt.value[0] if isinstance(evt.value, list) else evt.value
            logger.info(f"Getting dataset info for: {dataset_id}")
 
@@ -225,7 +213,8 @@
                gr.update(visible=video_count > 0),  # video_files_row
                f"Contains {video_count} video file{'s' if video_count != 1 else ''}",  # video_count_text
                gr.update(visible=webdataset_count > 0),  # webdataset_files_row
-               f"Contains {webdataset_count} WebDataset (.tar) file{'s' if webdataset_count != 1 else ''}"  # webdataset_count_text
+               f"Contains {webdataset_count} WebDataset (.tar) file{'s' if webdataset_count != 1 else ''}",  # webdataset_count_text
+               ""  # status_output
            )
        except Exception as e:
            logger.error(f"Error displaying dataset info: {str(e)}", exc_info=True)
@@ -236,38 +225,91 @@
                gr.update(visible=False),  # video_files_row
                "",  # video_count_text
                gr.update(visible=False),  # webdataset_files_row
-               ""  # webdataset_count_text
+               "",  # webdataset_count_text
+               ""  # status_output
            )
 
-    def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str) -> Tuple[gr.update, str]:
-        """Handle download of a group of files (videos or WebDatasets)"""
+    async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback):
+        """Wrapper for download_file_group that integrates with progress tracking"""
+        try:
+            # Set up the progress callback adapter
+            def progress_adapter(progress_value, desc=None, total=None):
+                # For a progress bar, we need to convert the values to a 0-1 range
+                if isinstance(progress_value, (int, float)):
+                    if total is not None and total > 0:
+                        # If we have a total, calculate the fraction
+                        fraction = min(1.0, progress_value / total)
+                    else:
+                        # Otherwise, just use the value directly (assumed to be 0-1)
+                        fraction = min(1.0, progress_value)
+
+                    # Update the progress with the calculated fraction
+                    progress_callback(fraction, desc=desc)
+
+            # Call the actual download function with our adapter
+            result = await self.app.importer.download_file_group(
+                dataset_id,
+                file_type,
+                enable_splitting,
+                progress_callback=progress_adapter
+            )
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
+            return f"Error: {str(e)}"
+
+    def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, progress=gr.Progress()) -> Tuple:
+        """Handle download of a group of files (videos or WebDatasets) with progress tracking"""
         try:
             if not dataset_id:
-                return gr.update(visible=False), "No dataset selected"
+                return ("No dataset selected",
+                        "No dataset selected",
+                        gr.update(),
+                        gr.update(),
+                        False)
 
             logger.info(f"Starting download of {file_type} files from dataset: {dataset_id}")
 
-            # Show loading indicator
-            loading_msg = gr.update(
-                value=f"## Downloading {file_type} files from {dataset_id}\n\nThis may take some time...",
-                visible=True
-            )
-            status_msg = f"Downloading {file_type} files from {dataset_id}..."
+            # Initialize progress tracking
+            progress(0, desc=f"Starting download of {file_type} files from {dataset_id}")
+
+            # Disable download buttons during the process
+            videos_btn_update = gr.update(interactive=False)
+            webdataset_btn_update = gr.update(interactive=False)
 
-            # Use the async version in a non-blocking way
-            asyncio.create_task(self._download_file_group_bg(dataset_id, file_type, enable_splitting))
+            # Run the download function with progress tracking
+            # We need to use asyncio.run to run the coroutine in a synchronous context
+            result = asyncio.run(self._download_with_progress(
+                dataset_id,
+                file_type,
+                enable_splitting,
+                progress
+            ))
 
-            return loading_msg, status_msg
+            # When download is complete, update the UI
+            progress(1.0, desc="Download complete!")
+
+            # Create a success message
+            success_msg = f"✅ Download complete! {result}"
+
+            # Update the UI components
+            return (
+                success_msg,  # status_output - shows the successful result
+                result,  # import_status
+                gr.update(interactive=True),  # download_videos_btn
+                gr.update(interactive=True),  # download_webdataset_btn
+                False  # download_in_progress
+            )
 
         except Exception as e:
-            error_msg = f"Error initiating download: {str(e)}"
+            error_msg = f"Error downloading {file_type} files: {str(e)}"
             logger.error(error_msg, exc_info=True)
-            return gr.update(visible=False), error_msg
-
-    async def _download_file_group_bg(self, dataset_id: str, file_type: str, enable_splitting: bool):
-        """Background task for group file download"""
-        try:
-            # This will execute in the background
-            await self.app.importer.download_file_group(dataset_id, file_type, enable_splitting)
-        except Exception as e:
-            logger.error(f"Error in background file group download: {str(e)}", exc_info=True)
+            return (
+                f"❌ Error: {error_msg}",  # status_output
+                error_msg,  # import_status
+                gr.update(interactive=True),  # download_videos_btn
+                gr.update(interactive=True),  # download_webdataset_btn
+                False  # download_in_progress
+            )
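The progress_adapter above exists because gr.Progress expects a fraction in [0, 1], while the service reports raw counts plus a total. The same arithmetic in isolation (a sketch of the conversion, not repo code):

def to_fraction(value, total=None):
    # Counts with a known total become value/total; bare values are
    # assumed to already be fractions. Both are clamped at 1.0.
    if total is not None and total > 0:
        return min(1.0, value / total)
    return min(1.0, float(value))

assert to_fraction(5, total=10) == 0.5
assert to_fraction(0.25) == 0.25
assert to_fraction(12, total=10) == 1.0

One caveat with the handler itself: asyncio.run raises RuntimeError when called from a thread that already has a running event loop, so download_file_group relies on Gradio executing this synchronous handler on a worker thread without one.
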
vms/tabs/import_tab/import_tab.py CHANGED
@@ -5,6 +5,7 @@ Parent import tab for Video Model Studio UI that contains sub-tabs
 import gradio as gr
 import logging
 import asyncio
+import threading
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple
 
@@ -82,44 +83,97 @@ class ImportTab(BaseTab):
         self.youtube_tab.connect_events()
         self.hub_tab.connect_events()
 
-    async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+    def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
         """Handle successful import of files"""
         videos = self.app.tabs["split_tab"].list_unprocessed_videos()
 
         # If scene detection isn't already running and there are videos to process,
         # and auto-splitting is enabled, start the detection
         if videos and not self.app.splitter.is_processing() and enable_splitting:
-            await self.app.tabs["split_tab"].start_scene_detection(enable_splitting)
+            # Start the scene detection in a separate thread
+            self._start_scene_detection_bg(enable_splitting)
             msg = "Starting automatic scene detection..."
         else:
             # Just copy files without splitting if auto-split disabled
-            for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
-                await self.app.splitter.process_video(video_file, enable_splitting=False)
+            self._start_copy_files_bg(enable_splitting)
             msg = "Copying videos without splitting..."
 
         self.app.tabs["caption_tab"].copy_files_to_training_dir(prompt_prefix)
 
-        # Start auto-captioning if enabled, and handle async generator properly
+        # Start auto-captioning if enabled
         if enable_automatic_content_captioning:
-            # Create a background task for captioning
-            asyncio.create_task(self.app.tabs["caption_tab"]._process_caption_generator(
-                DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
-                prompt_prefix
-            ))
+            self._start_captioning_bg(DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, prompt_prefix)
 
-        return {
-            "tabs": gr.Tabs(selected="split_tab"),
-            "video_list": videos,
-            "detect_status": msg
-        }
+        # Return the correct tuple of values as expected by the UI
+        return gr.update(selected="split_tab"), videos, msg
+
+    def _start_scene_detection_bg(self, enable_splitting):
+        """Start scene detection in a background thread"""
+        def run_async_in_thread():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                loop.run_until_complete(
+                    self.app.tabs["split_tab"].start_scene_detection(enable_splitting)
+                )
+            except Exception as e:
+                logger.error(f"Error in background scene detection: {str(e)}", exc_info=True)
+            finally:
+                loop.close()
+
+        thread = threading.Thread(target=run_async_in_thread)
+        thread.daemon = True
+        thread.start()
+
+    def _start_copy_files_bg(self, enable_splitting):
+        """Start copying files in a background thread"""
+        def run_async_in_thread():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                async def copy_files():
+                    for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
+                        await self.app.splitter.process_video(video_file, enable_splitting=False)
+
+                loop.run_until_complete(copy_files())
+            except Exception as e:
+                logger.error(f"Error in background file copying: {str(e)}", exc_info=True)
+            finally:
+                loop.close()
+
+        thread = threading.Thread(target=run_async_in_thread)
+        thread.daemon = True
+        thread.start()
+
+    def _start_captioning_bg(self, instructions, prompt_prefix):
+        """Start captioning in a background thread"""
+        def run_async_in_thread():
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                loop.run_until_complete(
+                    self.app.tabs["caption_tab"]._process_caption_generator(
+                        instructions, prompt_prefix
+                    )
+                )
+            except Exception as e:
+                logger.error(f"Error in background captioning: {str(e)}", exc_info=True)
+            finally:
+                loop.close()
+
+        thread = threading.Thread(target=run_async_in_thread)
+        thread.daemon = True
+        thread.start()
 
     async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
         """Handle post-import updates including titles"""
-        import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
+        # Call the non-async version since we need to return immediately for the UI
+        tabs, video_list, detect_status = self.on_import_success(
+            enable_splitting, enable_automatic_content_captioning, prompt_prefix
+        )
+
+        # Get updated titles
         titles = self.app.update_titles()
-        return (
-            import_result["tabs"],
-            import_result["video_list"],
-            import_result["detect_status"],
-            *titles
-        )
+
+        # Return all expected outputs
+        return tabs, video_list, detect_status, *titles
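The three _start_*_bg methods repeat one pattern: run a coroutine to completion on a fresh event loop inside a daemon thread, so the now-synchronous on_import_success can return to the UI immediately. A generic helper capturing that pattern (a sketch; the repo keeps three specialized copies):

import asyncio
import threading

def run_coro_in_daemon_thread(coro_factory):
    # coro_factory is a zero-argument callable returning a coroutine; it is
    # called inside the thread so the coroutine is created on the loop that
    # will actually run it.
    def runner():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(coro_factory())
        finally:
            loop.close()

    thread = threading.Thread(target=runner, daemon=True)
    thread.start()
    return thread

# Hypothetical usage mirroring _start_scene_detection_bg:
# run_coro_in_daemon_thread(lambda: split_tab.start_scene_detection(True))

Inside runner, asyncio.run(coro_factory()) would be equivalent to the explicit loop management shown here and in the diff.
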
vms/ui/video_trainer_ui.py CHANGED
@@ -72,7 +72,33 @@ class VideoTrainerUI:
 
         # Log recovery status
         logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")
+
+    def add_periodic_callback(self, callback_fn, interval=1.0):
+        """Add a periodic callback function to the UI
+
+        Args:
+            callback_fn: Function to call periodically
+            interval: Time in seconds between calls (default: 1.0)
+        """
+        try:
+            # Store a reference to the callback function
+            if not hasattr(self, "_periodic_callbacks"):
+                self._periodic_callbacks = []
+
+            self._periodic_callbacks.append(callback_fn)
+
+            # Add the callback to the Gradio app
+            self.app.add_callback(
+                interval,  # Interval in seconds
+                callback_fn,  # Function to call
+                inputs=None,  # No inputs needed
+                outputs=list(self.components.values())  # All components as possible outputs
+            )
+
+            logger.info(f"Added periodic callback {callback_fn.__name__} with interval {interval}s")
+        except Exception as e:
+            logger.error(f"Error adding periodic callback: {e}", exc_info=True)
 
     def create_ui(self):
         """Create the main Gradio UI"""
         with gr.Blocks(title="🎥 Video Model Studio") as app:
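
One caution on the new helper: add_callback is not part of Gradio's documented Blocks API, so the self.app.add_callback(...) call presumably targets an app-specific wrapper rather than a stock gr.Blocks; the surrounding try/except would swallow the AttributeError otherwise. As a library-agnostic point of comparison, a stdlib-only periodic scheduler (a hypothetical sketch, not from this repo) could look like:

import threading

def schedule_periodic(callback_fn, interval=1.0):
    # Invoke callback_fn every `interval` seconds on a daemon thread;
    # set the returned event to stop future invocations.
    stop_event = threading.Event()

    def loop():
        while not stop_event.wait(interval):
            callback_fn()

    threading.Thread(target=loop, daemon=True).start()
    return stop_event

# stop = schedule_periodic(lambda: print("tick"), interval=1.0)
# ... later: stop.set()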