"""
Hugging Face Hub dataset browser for Video Model Studio.
Handles searching, viewing, and downloading datasets from the Hub.
"""
import os
import shutil
import tempfile
import logging

import gradio as gr

from pathlib import Path
from typing import List, Dict, Optional, Tuple, Callable

from huggingface_hub import (
    HfApi,
    hf_hub_download,
    snapshot_download,
)

from vms.config import NORMALIZE_IMAGES_TO, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, add_prefix_to_caption, webdataset_handler
logger = logging.getLogger(__name__)
class HubDatasetBrowser:
"""Handles interactions with Hugging Face Hub datasets"""
def __init__(self, hf_api: HfApi):
"""Initialize with HfApi instance
Args:
hf_api: Hugging Face Hub API instance
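
        Example (hypothetical setup; token handling is left to the caller):
            browser = HubDatasetBrowser(HfApi())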
"""
self.hf_api = hf_api
def search_datasets(self, query: str) -> List[List[str]]:
"""Search for datasets on the Hugging Face Hub
Args:
query: Search query string
Returns:
List of datasets matching the query [id, title, downloads]
Note: We still return all columns internally, but the UI will only display the first column
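
        Example (illustrative output shape; real IDs and counts vary):
            rows = browser.search_datasets("cat videos")
            # e.g. [['someuser/cat-videos', 'cat-videos', 1234], ...]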
"""
try:
# Start with some filters to find video-related datasets
search_terms = query.strip() if query and query.strip() else "video"
logger.info(f"Searching datasets with query: '{search_terms}'")
# Fetch datasets that match the search
datasets = list(self.hf_api.list_datasets(
search=search_terms,
limit=50
))
# Format results for display
results = []
for ds in datasets:
# Extract relevant information
dataset_id = ds.id
# Safely get the title with fallbacks
                # card_data may be a DatasetCardData object rather than a plain
                # dict, so duck-type on .get() instead of checking isinstance(dict)
                card_data = getattr(ds, "card_data", None)
                title = ""
                if card_data is not None and hasattr(card_data, "get"):
                    title = card_data.get("name", "") or ""
if not title:
# Use the last part of the repo ID as a fallback
title = dataset_id.split("/")[-1]
# Safely get downloads
downloads = getattr(ds, "downloads", 0)
if downloads is None:
downloads = 0
results.append([dataset_id, title, downloads])
# Sort by downloads (most downloaded first)
results.sort(key=lambda x: x[2] if x[2] is not None else 0, reverse=True)
logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
return results
except Exception as e:
logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
return [[f"Error: {str(e)}", "", ""]]
def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
"""Get detailed information about a dataset
Args:
dataset_id: The dataset ID to get information for
Returns:
Tuple of (markdown_info, file_counts, file_groups)
- markdown_info: Markdown formatted string with dataset information
- file_counts: Dictionary with counts of each file type
- file_groups: Dictionary with lists of filenames grouped by type
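
        Example (illustrative shapes only):
            info, counts, groups = browser.get_dataset_info("someuser/my-videos")
            # counts -> {'video': 12, 'webdataset': 2}
            # groups -> {'video': ['clips/clip_0001.mp4', ...], 'webdataset': [...]}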
"""
try:
if not dataset_id:
logger.warning("No dataset ID provided to get_dataset_info")
return "No dataset selected", {}, {}
logger.info(f"Getting info for dataset: {dataset_id}")
# Get detailed information about the dataset
dataset_info = self.hf_api.dataset_info(dataset_id)
# Format the information for display
info_text = f"### {dataset_info.id}\n\n"
# Add description if available (with safer access)
            # card_data may be a DatasetCardData object rather than a plain
            # dict, so duck-type on .get() instead of checking isinstance(dict)
            card_data = getattr(dataset_info, "card_data", None)
            description = ""
            if card_data is not None and hasattr(card_data, "get"):
                description = card_data.get("description", "") or ""
if description:
info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"
            # Downloads and last-modified stats are currently omitted from the info card
# Group files by type
file_groups = {
"video": [],
"webdataset": []
}
siblings = getattr(dataset_info, "siblings", None) or []
# Extract files by type
for s in siblings:
if not hasattr(s, 'rfilename'):
continue
filename = s.rfilename
if filename.lower().endswith((".mp4", ".webm")):
file_groups["video"].append(filename)
elif filename.lower().endswith(".tar"):
file_groups["webdataset"].append(filename)
# Create file counts dictionary
file_counts = {
"video": len(file_groups["video"]),
"webdataset": len(file_groups["webdataset"])
}
logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
return info_text, file_counts, file_groups
except Exception as e:
logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
return f"Error loading dataset information: {str(e)}", {}, {}
async def download_file_group(
self,
dataset_id: str,
file_type: str,
enable_splitting: bool,
progress_callback: Optional[Callable] = None
) -> str:
"""Download all files of a specific type from the dataset
Args:
dataset_id: The dataset ID
file_type: Either "video" or "webdataset"
enable_splitting: Whether to enable automatic video splitting
progress_callback: Optional callback for progress updates
Returns:
Status message
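
        Note:
            When provided, the progress callback is assumed to follow the
            gr.Progress call convention: callback(value, desc=..., total=...).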
"""
try:
# Get dataset info to retrieve file list
_, _, file_groups = self.get_dataset_info(dataset_id)
# Get the list of files for the specified type
files = file_groups.get(file_type, [])
if not files:
return f"No {file_type} files found in the dataset"
logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
gr.Info(f"Starting download of {len(files)} {file_type} files from {dataset_id}")
# Initialize progress if callback provided
if progress_callback:
progress_callback(0, desc=f"Starting download of {len(files)} {file_type} files", total=len(files))
# Track counts for status message
video_count = 0
image_count = 0
# Create a temporary directory for downloads
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
# Process all files of the requested type
for i, filename in enumerate(files):
try:
# Update progress
if progress_callback:
progress_callback(
i,
desc=f"Downloading file {i+1}/{len(files)}: {Path(filename).name}",
total=len(files)
)
# Download the file
file_path = hf_hub_download(
repo_id=dataset_id,
filename=filename,
repo_type="dataset",
local_dir=temp_path
)
file_path = Path(file_path)
logger.info(f"Downloaded file to {file_path}")
#gr.Info(f"Downloaded {file_path.name} ({i+1}/{len(files)})")
# Process based on file type
if file_type == "video":
# Choose target directory based on auto-splitting setting
target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
target_path = target_dir / file_path.name
# Make sure filename is unique
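                            # Dedup convention: append "___<counter>" to the stem,
                            # stripping any existing "___" marker so counters don't stack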
counter = 1
while target_path.exists():
stem = Path(file_path.name).stem
if "___" in stem:
base_stem = stem.split("___")[0]
else:
base_stem = stem
target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
counter += 1
# Copy the video file
shutil.copy2(file_path, target_path)
logger.info(f"Processed video: {file_path.name} -> {target_path.name}")
# Try to download caption if it exists
try:
txt_filename = f"{Path(filename).stem}.txt"
for possible_path in [
Path(filename).with_suffix('.txt').as_posix(),
(Path(filename).parent / txt_filename).as_posix(),
]:
try:
txt_path = hf_hub_download(
repo_id=dataset_id,
filename=possible_path,
repo_type="dataset",
local_dir=temp_path
)
shutil.copy2(txt_path, target_path.with_suffix('.txt'))
logger.info(f"Copied caption for {file_path.name}")
break
except Exception:
# Caption file doesn't exist at this path, try next
pass
except Exception as e:
logger.warning(f"Error trying to download caption: {e}")
video_count += 1
elif file_type == "webdataset":
# Process the WebDataset archive
try:
logger.info(f"Processing WebDataset file: {file_path}")
vid_count, img_count = webdataset_handler.process_webdataset_shard(
file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
)
video_count += vid_count
image_count += img_count
except Exception as e:
logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
except Exception as e:
logger.warning(f"Error processing file {filename}: {e}")
# Update progress to complete
if progress_callback:
progress_callback(len(files), desc="Download complete", total=len(files))
# Generate status message
if file_type == "video":
status_msg = f"Successfully imported {video_count} videos from dataset {dataset_id}"
elif file_type == "webdataset":
parts = []
if video_count > 0:
parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
if image_count > 0:
parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
if parts:
status_msg = f"Successfully imported {' and '.join(parts)} from WebDataset archives"
else:
status_msg = f"No media was found in the WebDataset archives"
else:
status_msg = f"Unknown file type: {file_type}"
# Final notification
logger.info(f"✅ Download complete! {status_msg}")
# This info message will appear as a toast notification
gr.Info(f"✅ Download complete! {status_msg}")
return status_msg
except Exception as e:
error_msg = f"Error downloading {file_type} files: {str(e)}"
logger.error(error_msg, exc_info=True)
            # gr.Error is an exception class and only displays when raised;
            # use gr.Warning to surface a toast while still returning the message
            gr.Warning(error_msg)
return error_msg
async def download_dataset(
self,
dataset_id: str,
enable_splitting: bool,
progress_callback: Optional[Callable] = None
) -> Tuple[str, str]:
"""Download a dataset and process its video/image content
Args:
dataset_id: The dataset ID to download
enable_splitting: Whether to enable automatic video splitting
progress_callback: Optional callback for progress tracking
Returns:
Tuple of (loading_msg, status_msg)
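
        Example (hypothetical call from an async UI handler):
            loading, status = await browser.download_dataset(
                "someuser/my-videos", enable_splitting=True
            )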
"""
if not dataset_id:
logger.warning("No dataset ID provided for download")
return "No dataset selected", "Please select a dataset first"
try:
logger.info(f"Starting download of dataset: {dataset_id}")
loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."
status_msg = f"Downloading dataset: {dataset_id}..."
# Get dataset info to check for available files
dataset_info = self.hf_api.dataset_info(dataset_id)
# Check if there are video files or WebDataset files
            siblings = getattr(dataset_info, "siblings", None) or []
            video_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith((".mp4", ".webm"))]
            tar_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith(".tar")]
# Initialize progress tracking
total_files = len(video_files) + len(tar_files)
if progress_callback:
progress_callback(0, desc=f"Starting download of dataset: {dataset_id}", total=total_files)
# Create a temporary directory for downloads
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
files_processed = 0
# If we have video files, download them individually
if video_files:
loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")
for i, video_file in enumerate(video_files):
# Update progress
if progress_callback:
progress_callback(
files_processed,
desc=f"Downloading video {i+1}/{len(video_files)}: {Path(video_file).name}",
total=total_files
)
# Download the video file
try:
file_path = hf_hub_download(
repo_id=dataset_id,
filename=video_file,
repo_type="dataset",
local_dir=temp_path
)
                            # Fetch the caption file (if any) into the temp dir;
                            # process_files() below will copy it alongside the video
                            txt_filename = f"{Path(video_file).stem}.txt"
                            txt_path = None
for possible_path in [
Path(video_file).with_suffix('.txt').as_posix(),
(Path(video_file).parent / txt_filename).as_posix(),
]:
try:
txt_path = hf_hub_download(
repo_id=dataset_id,
filename=possible_path,
repo_type="dataset",
local_dir=temp_path
)
logger.info(f"Found caption file for {video_file}: {possible_path}")
break
                                except Exception as e:
                                    # Caption file doesn't exist at this path; try the next candidate
                                    logger.debug(f"No caption at {possible_path}: {str(e)}")
status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
logger.info(status_msg)
files_processed += 1
except Exception as e:
logger.warning(f"Error downloading {video_file}: {e}")
# If we have tar files, download them
if tar_files:
loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")
for i, tar_file in enumerate(tar_files):
# Update progress
if progress_callback:
progress_callback(
files_processed,
desc=f"Downloading WebDataset {i+1}/{len(tar_files)}: {Path(tar_file).name}",
total=total_files
)
try:
file_path = hf_hub_download(
repo_id=dataset_id,
filename=tar_file,
repo_type="dataset",
local_dir=temp_path
)
status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
logger.info(status_msg)
files_processed += 1
except Exception as e:
logger.warning(f"Error downloading {tar_file}: {e}")
# If no specific files were found, try downloading the entire repo
if not video_files and not tar_files:
loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")
if progress_callback:
progress_callback(0, desc=f"Downloading entire repository for {dataset_id}", total=1)
try:
snapshot_download(
repo_id=dataset_id,
repo_type="dataset",
local_dir=temp_path
)
status_msg = f"Downloaded entire repository for {dataset_id}"
logger.info(status_msg)
if progress_callback:
progress_callback(1, desc="Repository download complete", total=1)
except Exception as e:
logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
return loading_msg, f"Error downloading dataset: {str(e)}"
# Process the downloaded files
loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
logger.info(f"Processing downloaded files from {dataset_id}")
if progress_callback:
progress_callback(0, desc="Processing downloaded files", total=100)
# Count imported files
video_count = 0
image_count = 0
tar_count = 0
                # Helper coroutine that sorts the downloaded files into the
                # split/staging directories (the work inside is blocking I/O)
                async def process_files():
nonlocal video_count, image_count, tar_count
# Get total number of files to process
file_count = 0
for root, _, files in os.walk(temp_path):
file_count += len(files)
processed = 0
# Process all files in the temp directory
for root, _, files in os.walk(temp_path):
for file in files:
file_path = Path(root) / file
# Update progress (every 5 files to avoid too many updates)
if progress_callback and processed % 5 == 0:
if file_count > 0:
progress_percent = int((processed / file_count) * 100)
progress_callback(
progress_percent,
desc=f"Processing files: {processed}/{file_count}",
total=100
)
# Process videos
if file.lower().endswith((".mp4", ".webm")):
# Choose target path based on auto-splitting setting
target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
target_path = target_dir / file_path.name
# Make sure filename is unique
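                                # Same "___<counter>" dedup convention as in download_file_group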
counter = 1
while target_path.exists():
stem = Path(file_path.name).stem
if "___" in stem:
base_stem = stem.split("___")[0]
else:
base_stem = stem
target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
counter += 1
# Copy the video file
shutil.copy2(file_path, target_path)
logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")
# Copy associated caption file if it exists
txt_path = file_path.with_suffix('.txt')
if txt_path.exists():
shutil.copy2(txt_path, target_path.with_suffix('.txt'))
logger.info(f"Copied caption for {file_path.name}")
video_count += 1
# Process images
elif is_image_file(file_path):
target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
counter = 1
while target_path.exists():
target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
counter += 1
if normalize_image(file_path, target_path):
logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")
# Copy caption if available
txt_path = file_path.with_suffix('.txt')
if txt_path.exists():
caption = txt_path.read_text()
caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
target_path.with_suffix('.txt').write_text(caption)
logger.info(f"Processed caption for {file_path.name}")
image_count += 1
# Process WebDataset files
elif file.lower().endswith(".tar"):
# Process the WebDataset archive
try:
logger.info(f"Processing WebDataset file from dataset: {file}")
vid_count, img_count = webdataset_handler.process_webdataset_shard(
file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
)
tar_count += 1
video_count += vid_count
image_count += img_count
logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
except Exception as e:
logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
processed += 1
                # Process everything that was downloaded (process_files contains
                # no awaits, so this effectively runs inline)
                await process_files()
# Update progress to complete
if progress_callback:
progress_callback(100, desc="Processing complete", total=100)
# Generate final status message
parts = []
if video_count > 0:
parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
if image_count > 0:
parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
if tar_count > 0:
parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")
if parts:
status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
logger.info(status)
else:
status = f"No supported media files found in dataset {dataset_id}"
loading_msg = f"{loading_msg}\n\n⚠️ {status}"
logger.warning(status)
gr.Info(status)
return loading_msg, status
except Exception as e:
error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
logger.error(error_msg, exc_info=True)
return f"Error: {error_msg}", error_msg |