# Author: jbilcke-hf (HF Staff)
# Commit 246c64e: "fix issue with scene splitting"
"""
Hugging Face Hub dataset browser for Video Model Studio.
Handles searching, viewing, and downloading datasets from the Hub.
"""
import os
import shutil
import tempfile
import asyncio
import logging
import gradio as gr
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union, Callable
from huggingface_hub import (
HfApi,
hf_hub_download,
snapshot_download,
list_datasets
)
from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
# Module-scoped logger; its name follows the module path so output can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)
class HubDatasetBrowser:
    """Handles interactions with Hugging Face Hub datasets.

    Provides search, metadata inspection, and download/import of video, image
    and WebDataset (``.tar``) content from Hub dataset repositories into the
    project's VIDEOS_TO_SPLIT_PATH / STAGING_PATH folders.
    """

    # Extensions (matched case-insensitively) used to classify repository files.
    VIDEO_EXTENSIONS: Tuple[str, ...] = (".mp4", ".webm")
    WEBDATASET_EXTENSIONS: Tuple[str, ...] = (".tar",)

    def __init__(self, hf_api: HfApi):
        """Initialize with HfApi instance

        Args:
            hf_api: Hugging Face Hub API instance
        """
        self.hf_api = hf_api

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _card_data_get(card_data: Any, key: str) -> str:
        """Read a string field from a dataset's card metadata.

        ``huggingface_hub`` exposes card metadata as a card-data object with a
        dict-like ``get`` method, while raw payloads may be plain dicts; both
        are supported. Returns ``""`` when the metadata is missing, has no
        ``get``, lacks the key, or the value is not a string.
        """
        if card_data is None:
            return ""
        getter = getattr(card_data, "get", None)
        if not callable(getter):
            return ""
        try:
            value = getter(key, None)
        except (TypeError, KeyError):
            return ""
        return value if isinstance(value, str) else ""

    @classmethod
    def _is_video_filename(cls, filename: str) -> bool:
        """Return True if ``filename`` has one of VIDEO_EXTENSIONS."""
        return filename.lower().endswith(cls.VIDEO_EXTENSIONS)

    @classmethod
    def _is_webdataset_filename(cls, filename: str) -> bool:
        """Return True if ``filename`` has one of WEBDATASET_EXTENSIONS."""
        return filename.lower().endswith(cls.WEBDATASET_EXTENSIONS)

    @staticmethod
    def _unique_target_path(target_dir: Path, filename: str) -> Path:
        """Return a path inside ``target_dir`` for ``filename`` that does not exist yet.

        On collision the name becomes ``<base>___<n><suffix>``, where ``<base>``
        is the stem with any existing ``___`` counter stripped and ``n`` counts
        up from 1 until a free name is found.
        """
        candidate = target_dir / filename
        original = Path(filename)
        base_stem = original.stem.split("___")[0]
        counter = 1
        while candidate.exists():
            candidate = target_dir / f"{base_stem}___{counter}{original.suffix}"
            counter += 1
        return candidate

    @staticmethod
    async def _hub_download(dataset_id: str, filename: str, local_dir: Path) -> str:
        """Download one file from a dataset repo without blocking the event loop.

        ``hf_hub_download`` performs blocking network I/O, so it is run in the
        running loop's default executor. Returns the local file path.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: hf_hub_download(
                repo_id=dataset_id,
                filename=filename,
                repo_type="dataset",
                local_dir=local_dir,
            ),
        )

    async def _download_caption(self, dataset_id: str, filename: str, local_dir: Path) -> Optional[str]:
        """Try to download the ``.txt`` caption stored next to ``filename`` in the repo.

        Returns the local path of the caption, or ``None`` when it could not be
        downloaded (typically because the file has no caption in the dataset).
        """
        caption_name = Path(filename).with_suffix(".txt").as_posix()
        try:
            return await self._hub_download(dataset_id, caption_name, local_dir)
        except Exception as e:
            # A missing caption is expected for many datasets; not an error.
            logger.debug(f"No caption at {caption_name}: {str(e)}")
            return None

    def _import_local_files(
        self,
        source_dir: Path,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None,
    ) -> Tuple[int, int, int]:
        """Import every supported media file found under ``source_dir``.

        - Videos are copied to VIDEOS_TO_SPLIT_PATH (splitting enabled) or
          STAGING_PATH, together with a sibling ``.txt`` caption if present.
        - Images are normalized into STAGING_PATH via ``normalize_image``; a
          sibling caption is prefixed with DEFAULT_PROMPT_PREFIX.
        - ``.tar`` files are passed to ``webdataset_handler.process_webdataset_shard``.

        A failure on one file is logged and does not abort the remaining files.

        Returns:
            Tuple of (video_count, image_count, tar_count)
        """
        video_count = image_count = tar_count = 0

        all_files = [Path(root) / name for root, _, names in os.walk(source_dir) for name in names]
        file_count = len(all_files)

        for processed, file_path in enumerate(all_files):
            # Update progress every 5 files to avoid flooding the UI.
            if progress_callback and processed % 5 == 0 and file_count > 0:
                progress_callback(
                    int((processed / file_count) * 100),
                    desc=f"Processing files: {processed}/{file_count}",
                    total=100,
                )

            try:
                if self._is_video_filename(file_path.name):
                    target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                    target_path = self._unique_target_path(target_dir, file_path.name)
                    shutil.copy2(file_path, target_path)
                    logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")

                    txt_path = file_path.with_suffix(".txt")
                    if txt_path.exists():
                        shutil.copy2(txt_path, target_path.with_suffix(".txt"))
                        logger.info(f"Copied caption for {file_path.name}")
                    video_count += 1

                elif is_image_file(file_path):
                    target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
                    counter = 1
                    while target_path.exists():
                        target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                        counter += 1

                    if normalize_image(file_path, target_path):
                        logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")
                        txt_path = file_path.with_suffix(".txt")
                        if txt_path.exists():
                            # Explicit UTF-8: captions may contain non-ASCII text.
                            caption = txt_path.read_text(encoding="utf-8")
                            caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
                            target_path.with_suffix(".txt").write_text(caption, encoding="utf-8")
                            logger.info(f"Processed caption for {file_path.name}")
                        image_count += 1

                elif self._is_webdataset_filename(file_path.name):
                    # NOTE(review): WebDataset videos always go to VIDEOS_TO_SPLIT_PATH,
                    # regardless of enable_splitting — confirm this is intended.
                    try:
                        logger.info(f"Processing WebDataset file from dataset: {file_path.name}")
                        vid_count, img_count = webdataset_handler.process_webdataset_shard(
                            file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                        )
                        tar_count += 1
                        video_count += vid_count
                        image_count += img_count
                        logger.info(f"Extracted {vid_count} videos and {img_count} images from {file_path.name}")
                    except Exception as e:
                        logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
            except Exception as e:
                logger.warning(f"Error importing file {file_path}: {e}")

        return video_count, image_count, tar_count

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def search_datasets(self, query: str) -> List[List[str]]:
        """Search for datasets on the Hugging Face Hub

        Args:
            query: Search query string; an empty/blank query searches for "video"

        Returns:
            List of datasets matching the query as [id, title, downloads] rows,
            sorted by downloads (descending). On failure a single row
            [f"Error: ...", "", ""] is returned.
            Note: all columns are returned, but the UI only displays the first one.
        """
        try:
            search_terms = query.strip() if query and query.strip() else "video"
            logger.info(f"Searching datasets with query: '{search_terms}'")

            datasets = list(self.hf_api.list_datasets(search=search_terms, limit=50))

            results = []
            for ds in datasets:
                dataset_id = ds.id
                card_data = getattr(ds, "card_data", None)
                # Prefer card metadata; fall back to the last segment of the repo ID.
                title = (
                    self._card_data_get(card_data, "name")
                    or self._card_data_get(card_data, "pretty_name")
                    or dataset_id.split("/")[-1]
                )
                downloads = getattr(ds, "downloads", 0) or 0
                results.append([dataset_id, title, downloads])

            # Most downloaded first (downloads is always an int here).
            results.sort(key=lambda row: row[2], reverse=True)

            logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
            return results
        except Exception as e:
            logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
            return [[f"Error: {str(e)}", "", ""]]

    def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
        """Get detailed information about a dataset

        Args:
            dataset_id: The dataset ID to get information for

        Returns:
            Tuple of (markdown_info, file_counts, file_groups)
            - markdown_info: Markdown formatted string with dataset information,
              or a "No dataset selected" / error message
            - file_counts: Dictionary with counts of each file type
            - file_groups: Dictionary with lists of filenames grouped by type
              ("video", "webdataset"); both dicts are empty on failure
        """
        try:
            if not dataset_id:
                logger.warning("No dataset ID provided to get_dataset_info")
                return "No dataset selected", {}, {}

            logger.info(f"Getting info for dataset: {dataset_id}")
            dataset_info = self.hf_api.dataset_info(dataset_id)

            info_text = f"### {dataset_info.id}\n\n"

            description = self._card_data_get(getattr(dataset_info, "card_data", None), "description")
            if description:
                info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"

            file_groups: Dict[str, List[str]] = {"video": [], "webdataset": []}
            for sibling in getattr(dataset_info, "siblings", None) or []:
                filename = getattr(sibling, "rfilename", None)
                if not filename:
                    continue
                if self._is_video_filename(filename):
                    file_groups["video"].append(filename)
                elif self._is_webdataset_filename(filename):
                    file_groups["webdataset"].append(filename)

            file_counts = {kind: len(names) for kind, names in file_groups.items()}

            logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
            return info_text, file_counts, file_groups
        except Exception as e:
            logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
            return f"Error loading dataset information: {str(e)}", {}, {}

    async def download_file_group(
        self,
        dataset_id: str,
        file_type: str,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> str:
        """Download all files of a specific type from the dataset

        Args:
            dataset_id: The dataset ID
            file_type: Either "video" or "webdataset"
            enable_splitting: Whether to enable automatic video splitting
                (videos go to VIDEOS_TO_SPLIT_PATH when True, else STAGING_PATH)
            progress_callback: Optional callback for progress updates

        Returns:
            Status message (also used for errors; this method does not raise)
        """
        try:
            info_text, _, file_groups = self.get_dataset_info(dataset_id)
            # Empty groups mean get_dataset_info failed or got no ID; its message says why.
            if not file_groups:
                return info_text

            files = file_groups.get(file_type, [])
            if not files:
                return f"No {file_type} files found in the dataset"

            logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
            gr.Info(f"Starting download of {len(files)} {file_type} files from {dataset_id}")

            if progress_callback:
                progress_callback(0, desc=f"Starting download of {len(files)} {file_type} files", total=len(files))

            video_count = 0
            image_count = 0

            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)

                for i, filename in enumerate(files):
                    try:
                        if progress_callback:
                            progress_callback(
                                i,
                                desc=f"Downloading file {i+1}/{len(files)}: {Path(filename).name}",
                                total=len(files)
                            )

                        file_path = Path(await self._hub_download(dataset_id, filename, temp_path))
                        logger.info(f"Downloaded file to {file_path}")

                        if file_type == "video":
                            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                            target_path = self._unique_target_path(target_dir, file_path.name)
                            shutil.copy2(file_path, target_path)
                            logger.info(f"Processed video: {file_path.name} -> {target_path.name}")

                            # Best effort: a missing or uncopyable caption must not fail the video.
                            try:
                                txt_path = await self._download_caption(dataset_id, filename, temp_path)
                                if txt_path is not None:
                                    shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                    logger.info(f"Copied caption for {file_path.name}")
                            except Exception as e:
                                logger.warning(f"Error trying to download caption: {e}")

                            video_count += 1

                        elif file_type == "webdataset":
                            # NOTE(review): WebDataset videos always go to VIDEOS_TO_SPLIT_PATH,
                            # regardless of enable_splitting — confirm this is intended.
                            try:
                                logger.info(f"Processing WebDataset file: {file_path}")
                                vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                    file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                )
                                video_count += vid_count
                                image_count += img_count
                            except Exception as e:
                                logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
                    except Exception as e:
                        logger.warning(f"Error processing file {filename}: {e}")

            if progress_callback:
                progress_callback(len(files), desc="Download complete", total=len(files))

            if file_type == "video":
                status_msg = f"Successfully imported {video_count} videos from dataset {dataset_id}"
            elif file_type == "webdataset":
                parts = []
                if video_count > 0:
                    parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                if image_count > 0:
                    parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
                if parts:
                    status_msg = f"Successfully imported {' and '.join(parts)} from WebDataset archives"
                else:
                    status_msg = "No media was found in the WebDataset archives"
            else:
                status_msg = f"Unknown file type: {file_type}"

            logger.info(f"✅ Download complete! {status_msg}")
            # Shown to the user as a toast notification.
            gr.Info(f"✅ Download complete! {status_msg}")
            return status_msg

        except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
            # gr.Error is an exception class: constructing it without raising shows
            # nothing. gr.Warning displays a toast while we still return the message.
            gr.Warning(error_msg)
            return error_msg

    async def download_dataset(
        self,
        dataset_id: str,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> Tuple[str, str]:
        """Download a dataset and process its video/image content

        Video and ``.tar`` files are downloaded individually (with sibling
        ``.txt`` captions for videos). If the repo has neither, the whole
        repository snapshot is downloaded. Everything is then imported via
        ``_import_local_files``.

        Args:
            dataset_id: The dataset ID to download
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress tracking

        Returns:
            Tuple of (loading_msg, status_msg)
        """
        if not dataset_id:
            logger.warning("No dataset ID provided for download")
            return "No dataset selected", "Please select a dataset first"

        try:
            logger.info(f"Starting download of dataset: {dataset_id}")
            loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."

            dataset_info = self.hf_api.dataset_info(dataset_id)

            siblings = getattr(dataset_info, "siblings", None) or []
            repo_files = [s.rfilename for s in siblings if getattr(s, "rfilename", None)]
            video_files = [f for f in repo_files if self._is_video_filename(f)]
            tar_files = [f for f in repo_files if self._is_webdataset_filename(f)]

            total_files = len(video_files) + len(tar_files)
            if progress_callback:
                progress_callback(0, desc=f"Starting download of dataset: {dataset_id}", total=total_files)

            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                files_processed = 0

                if video_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
                    logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")

                    for i, video_file in enumerate(video_files):
                        if progress_callback:
                            progress_callback(
                                files_processed,
                                desc=f"Downloading video {i+1}/{len(video_files)}: {Path(video_file).name}",
                                total=total_files
                            )
                        try:
                            await self._hub_download(dataset_id, video_file, temp_path)
                            # The caption lands next to the video inside temp_path,
                            # where the import step below picks it up.
                            if await self._download_caption(dataset_id, video_file, temp_path) is not None:
                                logger.info(f"Found caption file for {video_file}")

                            logger.info(f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}")
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {video_file}: {e}")

                if tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
                    logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")

                    for i, tar_file in enumerate(tar_files):
                        if progress_callback:
                            progress_callback(
                                files_processed,
                                desc=f"Downloading WebDataset {i+1}/{len(tar_files)}: {Path(tar_file).name}",
                                total=total_files
                            )
                        try:
                            await self._hub_download(dataset_id, tar_file, temp_path)
                            logger.info(f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}")
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {tar_file}: {e}")

                # No recognised media files: fall back to downloading the whole repo.
                if not video_files and not tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
                    logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")

                    if progress_callback:
                        progress_callback(0, desc=f"Downloading entire repository for {dataset_id}", total=1)

                    try:
                        # snapshot_download blocks on network I/O; keep the event loop free.
                        loop = asyncio.get_running_loop()
                        await loop.run_in_executor(
                            None,
                            lambda: snapshot_download(
                                repo_id=dataset_id,
                                repo_type="dataset",
                                local_dir=temp_path
                            ),
                        )
                        logger.info(f"Downloaded entire repository for {dataset_id}")

                        if progress_callback:
                            progress_callback(1, desc="Repository download complete", total=1)
                    except Exception as e:
                        logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
                        return loading_msg, f"Error downloading dataset: {str(e)}"

                loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
                logger.info(f"Processing downloaded files from {dataset_id}")

                if progress_callback:
                    progress_callback(0, desc="Processing downloaded files", total=100)

                video_count, image_count, tar_count = self._import_local_files(
                    temp_path, enable_splitting, progress_callback
                )

                if progress_callback:
                    progress_callback(100, desc="Processing complete", total=100)

            parts = []
            if video_count > 0:
                parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
            if image_count > 0:
                parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
            if tar_count > 0:
                parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")

            if parts:
                status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
                loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
                logger.info(status)
            else:
                status = f"No supported media files found in dataset {dataset_id}"
                loading_msg = f"{loading_msg}\n\n⚠️ {status}"
                logger.warning(status)

            gr.Info(status)
            return loading_msg, status

        except Exception as e:
            error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}", error_msg