"""
Hugging Face Hub dataset browser for Video Model Studio.

Handles searching, viewing, and downloading datasets from the Hub.
"""

import os
import shutil
import tempfile
import asyncio
import logging
import gradio as gr
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any, Union, Callable

from huggingface_hub import (
    HfApi,
    hf_hub_download,
    snapshot_download,
    list_datasets
)

from vms.config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler

logger = logging.getLogger(__name__)

# Extensions (lower-case) treated as directly importable videos / WebDataset shards.
_VIDEO_EXTENSIONS = (".mp4", ".webm")
_WEBDATASET_EXTENSIONS = (".tar",)


class HubDatasetBrowser:
    """Handles interactions with Hugging Face Hub datasets"""

    def __init__(self, hf_api: HfApi):
        """Initialize with HfApi instance

        Args:
            hf_api: Hugging Face Hub API instance
        """
        self.hf_api = hf_api

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _card_data_str(card_data: Any, key: str) -> str:
        """Read a string field from a dataset's card data, returning "" if absent.

        Accepts a plain dict or any object exposing a dict-like ``get(key, default)``
        (huggingface_hub exposes card data as a ``CardData`` object, not always a dict).
        Non-string values are treated as missing.
        """
        getter = getattr(card_data, "get", None)
        if card_data is None or not callable(getter):
            return ""
        try:
            value = getter(key, "")
        except Exception:
            return ""
        return value if isinstance(value, str) else ""

    @staticmethod
    def _group_media_files(siblings: Any) -> Dict[str, List[str]]:
        """Split a repo's sibling entries into video files and WebDataset shards.

        Args:
            siblings: Iterable of sibling entries (or None); entries without a
                truthy ``rfilename`` attribute are skipped.

        Returns:
            Dict with keys "video" and "webdataset", each a list of repo-relative
            filenames. Extension matching is case-insensitive.
        """
        file_groups: Dict[str, List[str]] = {"video": [], "webdataset": []}
        for sibling in siblings or []:
            filename = getattr(sibling, "rfilename", None)
            if not filename:
                continue
            lowered = filename.lower()
            if lowered.endswith(_VIDEO_EXTENSIONS):
                file_groups["video"].append(filename)
            elif lowered.endswith(_WEBDATASET_EXTENSIONS):
                file_groups["webdataset"].append(filename)
        return file_groups

    @staticmethod
    def _unique_target_path(target_dir: Path, filename: str) -> Path:
        """Return ``target_dir / filename``, or a ``___N`` variant that does not exist yet.

        Any ``___`` suffix already present in the stem is stripped before
        numbering, so re-imports produce ``name___1``, ``name___2`` ... rather
        than ``name___1___1``.
        """
        name = Path(filename)
        base_stem = name.stem.split("___")[0]
        candidate = target_dir / filename
        counter = 1
        while candidate.exists():
            candidate = target_dir / f"{base_stem}___{counter}{name.suffix}"
            counter += 1
        return candidate

    def _download_caption(self, dataset_id: str, media_filename: str, local_dir: Path) -> Optional[Path]:
        """Best-effort download of the ``.txt`` caption stored next to a media file in the repo.

        Args:
            dataset_id: The dataset repo ID
            media_filename: Repo-relative path of the media file
            local_dir: Directory the caption is downloaded into

        Returns:
            Local path of the downloaded caption, or None if the download failed
            (most commonly because the repo has no such caption).
        """
        caption_filename = Path(media_filename).with_suffix(".txt").as_posix()
        try:
            return Path(hf_hub_download(
                repo_id=dataset_id,
                filename=caption_filename,
                repo_type="dataset",
                local_dir=local_dir
            ))
        except Exception as e:
            # A missing caption is the common case, so this is not treated as an error.
            logger.debug(f"No caption at {caption_filename}: {str(e)}")
            return None

    def _import_downloaded_files(
        self,
        temp_path: Path,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> Tuple[int, int, int]:
        """Import every supported file found (recursively) under ``temp_path``.

        - ``.mp4``/``.webm`` videos are copied to VIDEOS_TO_SPLIT_PATH when
          ``enable_splitting`` is true, else STAGING_PATH; a sibling ``.txt``
          caption is copied unchanged.
        - Files accepted by ``is_image_file`` are normalized into STAGING_PATH
          as ``.{NORMALIZE_IMAGES_TO}``; a sibling caption gets
          DEFAULT_PROMPT_PREFIX applied via ``add_prefix_to_caption``.
        - ``.tar`` files are passed to ``webdataset_handler.process_webdataset_shard``.

        A failure on one file is logged and the remaining files are still processed.

        Returns:
            Tuple of (video_count, image_count, tar_count)
        """
        video_count = 0
        image_count = 0
        tar_count = 0

        file_count = sum(len(files) for _, _, files in os.walk(temp_path))
        processed = 0

        for root, _, files in os.walk(temp_path):
            for file in files:
                file_path = Path(root) / file

                # Update progress (every 5 files to avoid too many updates)
                if progress_callback and processed % 5 == 0 and file_count > 0:
                    progress_callback(
                        int((processed / file_count) * 100),
                        desc=f"Processing files: {processed}/{file_count}",
                        total=100
                    )

                lowered = file.lower()
                try:
                    if lowered.endswith(_VIDEO_EXTENSIONS):
                        target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                        target_path = self._unique_target_path(target_dir, file_path.name)

                        shutil.copy2(file_path, target_path)
                        logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")

                        # Captions were downloaded next to their video, so look for a sibling .txt
                        txt_path = file_path.with_suffix('.txt')
                        if txt_path.exists():
                            shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                            logger.info(f"Copied caption for {file_path.name}")

                        video_count += 1

                    elif is_image_file(file_path):
                        target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
                        counter = 1
                        while target_path.exists():
                            target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                            counter += 1

                        if normalize_image(file_path, target_path):
                            logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")

                            txt_path = file_path.with_suffix('.txt')
                            if txt_path.exists():
                                caption = txt_path.read_text(encoding="utf-8")
                                caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
                                target_path.with_suffix('.txt').write_text(caption, encoding="utf-8")
                                logger.info(f"Processed caption for {file_path.name}")

                            image_count += 1

                    elif lowered.endswith(_WEBDATASET_EXTENSIONS):
                        try:
                            logger.info(f"Processing WebDataset file from dataset: {file}")
                            # NOTE(review): shard videos always go to VIDEOS_TO_SPLIT_PATH,
                            # regardless of enable_splitting — confirm this is intended.
                            vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                            )
                            tar_count += 1
                            video_count += vid_count
                            image_count += img_count
                            logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
                        except Exception as e:
                            logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
                except Exception as e:
                    logger.warning(f"Error processing file {file_path}: {e}")

                processed += 1

        return video_count, image_count, tar_count

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def search_datasets(self, query: str) -> List[List[Any]]:
        """Search for datasets on the Hugging Face Hub

        Args:
            query: Search query string; an empty/blank query searches for "video"

        Returns:
            Up to 50 rows of [id, title, downloads] sorted by downloads (descending),
            or a single [error_message, "", ""] row on failure.

        Note:
            We still return all columns internally, but the UI will only display the first column
        """
        try:
            search_terms = query.strip() if query and query.strip() else "video"
            logger.info(f"Searching datasets with query: '{search_terms}'")

            datasets = list(self.hf_api.list_datasets(
                search=search_terms,
                limit=50
            ))

            results = []
            for ds in datasets:
                dataset_id = ds.id

                title = self._card_data_str(getattr(ds, "card_data", None), "name")
                if not title:
                    # Use the last part of the repo ID as a fallback
                    title = dataset_id.split("/")[-1]

                downloads = getattr(ds, "downloads", 0) or 0

                results.append([dataset_id, title, downloads])

            # Most downloaded first
            results.sort(key=lambda row: row[2], reverse=True)

            logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
            return results
        except Exception as e:
            logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
            return [[f"Error: {str(e)}", "", ""]]

    def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
        """Get detailed information about a dataset

        Args:
            dataset_id: The dataset ID to get information for

        Returns:
            Tuple of (markdown_info, file_counts, file_groups)
            - markdown_info: Markdown formatted string with dataset information
            - file_counts: Dictionary with counts of each file type ("video", "webdataset")
            - file_groups: Dictionary with lists of filenames grouped by type
            On failure, markdown_info holds an error message and both dicts are empty.
        """
        try:
            if not dataset_id:
                logger.warning("No dataset ID provided to get_dataset_info")
                return "No dataset selected", {}, {}

            logger.info(f"Getting info for dataset: {dataset_id}")

            dataset_info = self.hf_api.dataset_info(dataset_id)

            info_text = f"### {dataset_info.id}\n\n"

            description = self._card_data_str(getattr(dataset_info, "card_data", None), "description")
            if description:
                # Truncate long descriptions to keep the panel readable
                info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"

            file_groups = self._group_media_files(getattr(dataset_info, "siblings", None))
            file_counts = {group: len(names) for group, names in file_groups.items()}

            logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
            return info_text, file_counts, file_groups

        except Exception as e:
            logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
            return f"Error loading dataset information: {str(e)}", {}, {}

    async def download_file_group(
        self,
        dataset_id: str,
        file_type: str,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> str:
        """Download all files of a specific type from the dataset

        Args:
            dataset_id: The dataset ID
            file_type: Either "video" or "webdataset"
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress updates

        Returns:
            Status message (errors are reported in the message, not raised)
        """
        try:
            _, _, file_groups = self.get_dataset_info(dataset_id)

            files = file_groups.get(file_type, [])
            if not files:
                return f"No {file_type} files found in the dataset"

            logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
            gr.Info(f"Starting download of {len(files)} {file_type} files from {dataset_id}")

            if progress_callback:
                progress_callback(0, desc=f"Starting download of {len(files)} {file_type} files", total=len(files))

            video_count = 0
            image_count = 0

            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)

                for i, filename in enumerate(files):
                    try:
                        if progress_callback:
                            progress_callback(
                                i,
                                desc=f"Downloading file {i+1}/{len(files)}: {Path(filename).name}",
                                total=len(files)
                            )

                        file_path = Path(hf_hub_download(
                            repo_id=dataset_id,
                            filename=filename,
                            repo_type="dataset",
                            local_dir=temp_path
                        ))
                        logger.info(f"Downloaded file to {file_path}")

                        if file_type == "video":
                            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                            target_path = self._unique_target_path(target_dir, file_path.name)

                            shutil.copy2(file_path, target_path)
                            logger.info(f"Processed video: {file_path.name} -> {target_path.name}")

                            # Caption is optional; a failure here must not drop the video itself
                            try:
                                caption_path = self._download_caption(dataset_id, filename, temp_path)
                                if caption_path is not None:
                                    shutil.copy2(caption_path, target_path.with_suffix('.txt'))
                                    logger.info(f"Copied caption for {file_path.name}")
                            except Exception as e:
                                logger.warning(f"Error trying to download caption: {e}")

                            video_count += 1

                        elif file_type == "webdataset":
                            try:
                                logger.info(f"Processing WebDataset file: {file_path}")
                                # NOTE(review): shard videos always go to VIDEOS_TO_SPLIT_PATH,
                                # regardless of enable_splitting — confirm this is intended.
                                vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                    file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                )
                                video_count += vid_count
                                image_count += img_count
                            except Exception as e:
                                logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)

                    except Exception as e:
                        logger.warning(f"Error processing file {filename}: {e}")

            if progress_callback:
                progress_callback(len(files), desc="Download complete", total=len(files))

            if file_type == "video":
                status_msg = f"Successfully imported {video_count} videos from dataset {dataset_id}"
            elif file_type == "webdataset":
                parts = []
                if video_count > 0:
                    parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                if image_count > 0:
                    parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")

                if parts:
                    status_msg = f"Successfully imported {' and '.join(parts)} from WebDataset archives"
                else:
                    status_msg = "No media was found in the WebDataset archives"
            else:
                status_msg = f"Unknown file type: {file_type}"

            logger.info(f"✅ Download complete! {status_msg}")
            # This info message will appear as a toast notification
            gr.Info(f"✅ Download complete! {status_msg}")

            return status_msg

        except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
            # gr.Error only displays when raised; gr.Warning shows a toast while
            # keeping this method's "return a status string" contract.
            gr.Warning(error_msg)
            return error_msg

    async def download_dataset(
        self,
        dataset_id: str,
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> Tuple[str, str]:
        """Download a dataset and process its video/image content

        Video files and WebDataset shards are fetched individually (with any
        sibling caption); if the repo has neither, the whole repo snapshot is
        downloaded instead. Everything is then imported from a temp directory.

        Args:
            dataset_id: The dataset ID to download
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress tracking

        Returns:
            Tuple of (loading_msg, status_msg)
        """
        if not dataset_id:
            logger.warning("No dataset ID provided for download")
            return "No dataset selected", "Please select a dataset first"

        try:
            logger.info(f"Starting download of dataset: {dataset_id}")
            loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."
            status_msg = f"Downloading dataset: {dataset_id}..."

            dataset_info = self.hf_api.dataset_info(dataset_id)
            file_groups = self._group_media_files(getattr(dataset_info, "siblings", None))
            video_files = file_groups["video"]
            tar_files = file_groups["webdataset"]

            total_files = len(video_files) + len(tar_files)
            if progress_callback:
                progress_callback(0, desc=f"Starting download of dataset: {dataset_id}", total=total_files)

            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                files_processed = 0

                if video_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
                    logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")

                    for i, video_file in enumerate(video_files):
                        if progress_callback:
                            progress_callback(
                                files_processed,
                                desc=f"Downloading video {i+1}/{len(video_files)}: {Path(video_file).name}",
                                total=total_files
                            )

                        try:
                            hf_hub_download(
                                repo_id=dataset_id,
                                filename=video_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )

                            # The caption lands next to the video inside temp_path and is
                            # picked up by _import_downloaded_files below.
                            caption_path = self._download_caption(dataset_id, video_file, temp_path)
                            if caption_path is not None:
                                logger.info(f"Found caption file for {video_file}: {caption_path.name}")

                            status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
                            logger.info(status_msg)
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {video_file}: {e}")

                if tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
                    logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")

                    for i, tar_file in enumerate(tar_files):
                        if progress_callback:
                            progress_callback(
                                files_processed,
                                desc=f"Downloading WebDataset {i+1}/{len(tar_files)}: {Path(tar_file).name}",
                                total=total_files
                            )

                        try:
                            hf_hub_download(
                                repo_id=dataset_id,
                                filename=tar_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )
                            status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
                            logger.info(status_msg)
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {tar_file}: {e}")

                # Fallback: no recognised media files, so fetch the whole repository
                if not video_files and not tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
                    logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")

                    if progress_callback:
                        progress_callback(0, desc=f"Downloading entire repository for {dataset_id}", total=1)

                    try:
                        snapshot_download(
                            repo_id=dataset_id,
                            repo_type="dataset",
                            local_dir=temp_path
                        )
                        status_msg = f"Downloaded entire repository for {dataset_id}"
                        logger.info(status_msg)

                        if progress_callback:
                            progress_callback(1, desc="Repository download complete", total=1)
                    except Exception as e:
                        logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
                        return loading_msg, f"Error downloading dataset: {str(e)}"

                loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
                logger.info(f"Processing downloaded files from {dataset_id}")

                if progress_callback:
                    progress_callback(0, desc="Processing downloaded files", total=100)

                video_count, image_count, tar_count = self._import_downloaded_files(
                    temp_path, enable_splitting, progress_callback
                )

                if progress_callback:
                    progress_callback(100, desc="Processing complete", total=100)

            parts = []
            if video_count > 0:
                parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
            if image_count > 0:
                parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
            if tar_count > 0:
                parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")

            if parts:
                status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
                loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
                logger.info(status)
            else:
                status = f"No supported media files found in dataset {dataset_id}"
                loading_msg = f"{loading_msg}\n\n⚠️ {status}"
                logger.warning(status)

            gr.Info(status)

            return loading_msg, status

        except Exception as e:
            error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}", error_msg