"""
Hugging Face Hub dataset browser for Video Model Studio.
Handles searching, viewing, and downloading datasets from the Hub.
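
Typical usage (illustrative; anonymous HfApi access suffices for public datasets):

    browser = HubDatasetBrowser(HfApi())
    results = browser.search_datasets("video")  # [[dataset_id, title, downloads], ...]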
"""

import os
import shutil
import tempfile
import asyncio
import logging
import gradio as gr
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Callable

from huggingface_hub import (
    HfApi,
    hf_hub_download,
    snapshot_download,
)

from vms.config import NORMALIZE_IMAGES_TO, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
from vms.utils import normalize_image, is_image_file, add_prefix_to_caption, webdataset_handler

logger = logging.getLogger(__name__)

class HubDatasetBrowser:
    """Handles interactions with Hugging Face Hub datasets"""
    
    def __init__(self, hf_api: HfApi):
        """Initialize with HfApi instance
        
        Args:
            hf_api: Hugging Face Hub API instance
        """
        self.hf_api = hf_api
    
    def search_datasets(self, query: str) -> List[List[str]]:
        """Search for datasets on the Hugging Face Hub
        
        Args:
            query: Search query string
            
        Returns:
            List of [dataset_id, title, downloads] rows, sorted by downloads
            (all three columns are returned; the UI displays only the first)
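
        Example (illustrative; performs a live Hub query):
            browser = HubDatasetBrowser(HfApi())
            rows = browser.search_datasets("cats")
            top_dataset_id = rows[0][0]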
        """
        try:
            # Fall back to a generic "video" query when the search box is empty
            search_terms = query.strip() if query and query.strip() else "video"
            logger.info(f"Searching datasets with query: '{search_terms}'")
            
            # Fetch datasets that match the search
            datasets = list(self.hf_api.list_datasets(
                search=search_terms,
                limit=50
            ))
            
            # Format results for display
            results = []
            for ds in datasets:
                # Extract relevant information
                dataset_id = ds.id
                
                # Safely get the title with fallbacks
                card_data = getattr(ds, "card_data", None)
                title = ""
                
                # card_data may be a dict or a DatasetCardData object; both support .get()
                if card_data is not None and hasattr(card_data, "get"):
                    title = card_data.get("name", "") or ""
                
                if not title:
                    # Use the last part of the repo ID as a fallback
                    title = dataset_id.split("/")[-1]
                
                # Safely get downloads
                downloads = getattr(ds, "downloads", 0)
                if downloads is None:
                    downloads = 0
                
                results.append([dataset_id, title, downloads])
            
            # Sort by downloads (most downloaded first)
            results.sort(key=lambda x: x[2] if x[2] is not None else 0, reverse=True)
            
            logger.info(f"Found {len(results)} datasets matching '{search_terms}'")
            return results
        
        except Exception as e:
            logger.error(f"Error searching datasets: {str(e)}", exc_info=True)
            return [[f"Error: {str(e)}", "", ""]]
            
    def get_dataset_info(self, dataset_id: str) -> Tuple[str, Dict[str, int], Dict[str, List[str]]]:
        """Get detailed information about a dataset
        
        Args:
            dataset_id: The dataset ID to get information for
            
        Returns:
            Tuple of (markdown_info, file_counts, file_groups)
            - markdown_info: Markdown formatted string with dataset information
            - file_counts: Dictionary with counts of each file type
            - file_groups: Dictionary with lists of filenames grouped by type
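
        Example (illustrative; "user/my-dataset" is a placeholder ID):
            info_md, counts, groups = browser.get_dataset_info("user/my-dataset")
            counts.get("video", 0)    # number of .mp4/.webm files
            groups["webdataset"]      # filenames of .tar shards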
        """
        try:
            if not dataset_id:
                logger.warning("No dataset ID provided to get_dataset_info")
                return "No dataset selected", {}, {}
                
            logger.info(f"Getting info for dataset: {dataset_id}")
                
            # Get detailed information about the dataset
            dataset_info = self.hf_api.dataset_info(dataset_id)
            
            # Format the information for display
            info_text = f"### {dataset_info.id}\n\n"
            
            # Add description if available (with safer access)
            card_data = getattr(dataset_info, "card_data", None)
            description = ""
            
            # card_data may be a dict or a DatasetCardData object; both support .get()
            if card_data is not None and hasattr(card_data, "get"):
                description = card_data.get("description", "") or ""
                
            if description:
                info_text += f"{description[:500]}{'...' if len(description) > 500 else ''}\n\n"
            
            # Group files by type
            file_groups = {
                "video": [],
                "webdataset": []
            }
            
            siblings = getattr(dataset_info, "siblings", None) or []
            
            # Extract files by type
            for s in siblings:
                if not hasattr(s, 'rfilename'):
                    continue
                    
                filename = s.rfilename
                if filename.lower().endswith((".mp4", ".webm")):
                    file_groups["video"].append(filename)
                elif filename.lower().endswith(".tar"):
                    file_groups["webdataset"].append(filename)
            
            # Create file counts dictionary
            file_counts = {
                "video": len(file_groups["video"]),
                "webdataset": len(file_groups["webdataset"])
            }
            
            logger.info(f"Successfully retrieved info for dataset: {dataset_id}")
            return info_text, file_counts, file_groups
            
        except Exception as e:
            logger.error(f"Error getting dataset info: {str(e)}", exc_info=True)
            return f"Error loading dataset information: {str(e)}", {}, {}
    
    async def download_file_group(
        self, 
        dataset_id: str, 
        file_type: str, 
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> str:
        """Download all files of a specific type from the dataset
        
        Args:
            dataset_id: The dataset ID
            file_type: Either "video" or "webdataset"
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress updates
            
        Returns:
            Status message
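
        Example (illustrative; must be awaited, e.g. from a Gradio event handler):
            msg = await browser.download_file_group(
                "user/my-dataset",  # placeholder dataset ID
                "video",
                enable_splitting=True,
            )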
        """
        try:
            # Get dataset info to retrieve file list
            _, _, file_groups = self.get_dataset_info(dataset_id)
            
            # Get the list of files for the specified type
            files = file_groups.get(file_type, [])
            
            if not files:
                return f"No {file_type} files found in the dataset"
            
            logger.info(f"Downloading {len(files)} {file_type} files from dataset {dataset_id}")
            gr.Info(f"Starting download of {len(files)} {file_type} files from {dataset_id}")
            
            # Initialize progress if callback provided
            if progress_callback:
                progress_callback(0, desc=f"Starting download of {len(files)} {file_type} files", total=len(files))
            
            # Track counts for status message
            video_count = 0
            image_count = 0
            
            # Create a temporary directory for downloads
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                
                # Process all files of the requested type
                for i, filename in enumerate(files):
                    try:
                        # Update progress
                        if progress_callback:
                            progress_callback(
                                i, 
                                desc=f"Downloading file {i+1}/{len(files)}: {Path(filename).name}",
                                total=len(files)
                            )
                        
                        # Download the file
                        file_path = hf_hub_download(
                            repo_id=dataset_id,
                            filename=filename,
                            repo_type="dataset",
                            local_dir=temp_path
                        )
                        
                        file_path = Path(file_path)
                        logger.info(f"Downloaded file to {file_path}")
                        
                        # Process based on file type
                        if file_type == "video":
                            # Choose target directory based on auto-splitting setting
                            target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                            target_path = target_dir / file_path.name
                            
                            # Make sure filename is unique
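                            # Collisions are renamed "<stem>___<n><suffix>"; reuse the base
                            # stem so repeated imports don't stack "___" suffixes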
                            counter = 1
                            while target_path.exists():
                                stem = Path(file_path.name).stem
                                if "___" in stem:
                                    base_stem = stem.split("___")[0]
                                else:
                                    base_stem = stem
                                target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
                                counter += 1
                                
                            # Copy the video file
                            shutil.copy2(file_path, target_path)
                            logger.info(f"Processed video: {file_path.name} -> {target_path.name}")
                            
                            # Try to download the sidecar caption (a .txt file with the
                            # same stem, stored alongside the video) if one exists
                            try:
                                caption_file = Path(filename).with_suffix('.txt').as_posix()
                                txt_path = hf_hub_download(
                                    repo_id=dataset_id,
                                    filename=caption_file,
                                    repo_type="dataset",
                                    local_dir=temp_path
                                )
                                shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                logger.info(f"Copied caption for {file_path.name}")
                            except Exception as e:
                                # Most videos simply have no caption; log at debug level and move on
                                logger.debug(f"No caption found for {filename}: {e}")
                            
                            video_count += 1
                            
                        elif file_type == "webdataset":
                            # Process the WebDataset archive
                            try:
                                logger.info(f"Processing WebDataset file: {file_path}")
                                vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                    file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                )
                                video_count += vid_count
                                image_count += img_count
                            except Exception as e:
                                logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
                    
                    except Exception as e:
                        logger.warning(f"Error processing file {filename}: {e}")
                
                # Update progress to complete
                if progress_callback:
                    progress_callback(len(files), desc="Download complete", total=len(files))
                
                # Generate status message
                if file_type == "video":
                    status_msg = f"Successfully imported {video_count} videos from dataset {dataset_id}"
                elif file_type == "webdataset":
                    parts = []
                    if video_count > 0:
                        parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                    if image_count > 0:
                        parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
                        
                    if parts:
                        status_msg = f"Successfully imported {' and '.join(parts)} from WebDataset archives"
                    else:
                        status_msg = "No media found in the WebDataset archives"
                else:
                    status_msg = f"Unknown file type: {file_type}"
                
                # Final notification
                logger.info(f"✅ Download complete! {status_msg}")
                # This info message will appear as a toast notification
                gr.Info(f"✅ Download complete! {status_msg}")
                
                return status_msg
                
        except Exception as e:
            error_msg = f"Error downloading {file_type} files: {str(e)}"
            logger.error(error_msg, exc_info=True)
            # gr.Error only takes effect when raised; use a Warning toast so the
            # message still surfaces in the UI while the error text is returned below
            gr.Warning(error_msg)
            return error_msg
    
    async def download_dataset(
        self, 
        dataset_id: str, 
        enable_splitting: bool,
        progress_callback: Optional[Callable] = None
    ) -> Tuple[str, str]:
        """Download a dataset and process its video/image content
        
        Args:
            dataset_id: The dataset ID to download
            enable_splitting: Whether to enable automatic video splitting
            progress_callback: Optional callback for progress tracking
            
        Returns:
            Tuple of (loading_msg, status_msg)
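
        Example (illustrative; from synchronous code, drive it with asyncio.run):
            loading_md, status = asyncio.run(
                browser.download_dataset(
                    "user/my-dataset",  # placeholder dataset ID
                    enable_splitting=True,
                )
            )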
        """
        if not dataset_id:
            logger.warning("No dataset ID provided for download")
            return "No dataset selected", "Please select a dataset first"
        
        try:
            logger.info(f"Starting download of dataset: {dataset_id}")
            loading_msg = f"## Downloading dataset: {dataset_id}\n\nThis may take some time depending on the dataset size..."
            status_msg = f"Downloading dataset: {dataset_id}..."
            
            # Get dataset info to check for available files
            dataset_info = self.hf_api.dataset_info(dataset_id)
            
            # Check if there are video files or WebDataset files
            video_files = []
            tar_files = []
            
            siblings = getattr(dataset_info, "siblings", None) or []
            if siblings:
                video_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith((".mp4", ".webm"))]
                tar_files = [s.rfilename for s in siblings if hasattr(s, 'rfilename') and s.rfilename.lower().endswith(".tar")]
            
            # Initialize progress tracking
            total_files = len(video_files) + len(tar_files)
            if progress_callback:
                progress_callback(0, desc=f"Starting download of dataset: {dataset_id}", total=total_files)
            
            # Create a temporary directory for downloads
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                files_processed = 0
                
                # If we have video files, download them individually
                if video_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(video_files)} video files..."
                    logger.info(f"Downloading {len(video_files)} video files from {dataset_id}")
                    
                    for i, video_file in enumerate(video_files):
                        # Update progress
                        if progress_callback:
                            progress_callback(
                                files_processed, 
                                desc=f"Downloading video {i+1}/{len(video_files)}: {Path(video_file).name}",
                                total=total_files
                            )
                            
                        # Download the video file
                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
                                filename=video_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )
                            
                            # Download the sidecar caption (same stem, .txt suffix) if present;
                            # it lands in temp_path and is picked up by the processing pass below
                            caption_file = Path(video_file).with_suffix('.txt').as_posix()
                            try:
                                hf_hub_download(
                                    repo_id=dataset_id,
                                    filename=caption_file,
                                    repo_type="dataset",
                                    local_dir=temp_path
                                )
                                logger.info(f"Found caption file for {video_file}: {caption_file}")
                            except Exception as e:
                                logger.debug(f"No caption at {caption_file}: {str(e)}")
                                
                            status_msg = f"Downloaded video {i+1}/{len(video_files)} from {dataset_id}"
                            logger.info(status_msg)
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {video_file}: {e}")
                
                # If we have tar files, download them
                if tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading {len(tar_files)} WebDataset files..."
                    logger.info(f"Downloading {len(tar_files)} WebDataset files from {dataset_id}")
                    
                    for i, tar_file in enumerate(tar_files):
                        # Update progress
                        if progress_callback:
                            progress_callback(
                                files_processed, 
                                desc=f"Downloading WebDataset {i+1}/{len(tar_files)}: {Path(tar_file).name}",
                                total=total_files
                            )
                            
                        try:
                            file_path = hf_hub_download(
                                repo_id=dataset_id,
                                filename=tar_file,
                                repo_type="dataset",
                                local_dir=temp_path
                            )
                            status_msg = f"Downloaded WebDataset {i+1}/{len(tar_files)} from {dataset_id}"
                            logger.info(status_msg)
                            files_processed += 1
                        except Exception as e:
                            logger.warning(f"Error downloading {tar_file}: {e}")
                
                # If no specific files were found, try downloading the entire repo
                if not video_files and not tar_files:
                    loading_msg = f"{loading_msg}\n\nDownloading entire dataset repository..."
                    logger.info(f"No specific media files found, downloading entire repository for {dataset_id}")
                    
                    if progress_callback:
                        progress_callback(0, desc=f"Downloading entire repository for {dataset_id}", total=1)
                    
                    try:
                        snapshot_download(
                            repo_id=dataset_id,
                            repo_type="dataset",
                            local_dir=temp_path
                        )
                        status_msg = f"Downloaded entire repository for {dataset_id}"
                        logger.info(status_msg)
                        
                        if progress_callback:
                            progress_callback(1, desc="Repository download complete", total=1)
                    except Exception as e:
                        logger.error(f"Error downloading dataset snapshot: {e}", exc_info=True)
                        return loading_msg, f"Error downloading dataset: {str(e)}"
                
                # Process the downloaded files
                loading_msg = f"{loading_msg}\n\nProcessing downloaded files..."
                logger.info(f"Processing downloaded files from {dataset_id}")
                
                if progress_callback:
                    progress_callback(0, desc="Processing downloaded files", total=100)
                
                # Count imported files
                video_count = 0
                image_count = 0
                tar_count = 0
                
                # Local coroutine wrapping the file-processing pass so it can be awaited below
                async def process_files():
                    nonlocal video_count, image_count, tar_count
                    
                    # Get total number of files to process
                    file_count = 0
                    for root, _, files in os.walk(temp_path):
                        file_count += len(files)
                    
                    processed = 0
                    
                    # Process all files in the temp directory
                    for root, _, files in os.walk(temp_path):
                        for file in files:
                            file_path = Path(root) / file
                            
                            # Update progress (every 5 files to avoid too many updates)
                            if progress_callback and processed % 5 == 0:
                                if file_count > 0:
                                    progress_percent = int((processed / file_count) * 100)
                                    progress_callback(
                                        progress_percent, 
                                        desc=f"Processing files: {processed}/{file_count}",
                                        total=100
                                    )
                            
                            # Process videos
                            if file.lower().endswith((".mp4", ".webm")):
                                # Choose target path based on auto-splitting setting
                                target_dir = VIDEOS_TO_SPLIT_PATH if enable_splitting else STAGING_PATH
                                target_path = target_dir / file_path.name
                                
                                # Make sure filename is unique
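                                # Collisions are renamed "<stem>___<n><suffix>"; reuse the base
                                # stem so repeated imports don't stack "___" suffixes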
                                counter = 1
                                while target_path.exists():
                                    stem = Path(file_path.name).stem
                                    if "___" in stem:
                                        base_stem = stem.split("___")[0]
                                    else:
                                        base_stem = stem
                                    target_path = target_dir / f"{base_stem}___{counter}{Path(file_path.name).suffix}"
                                    counter += 1
                                    
                                # Copy the video file
                                shutil.copy2(file_path, target_path)
                                logger.info(f"Processed video from dataset: {file_path.name} -> {target_path.name}")
                                
                                # Copy associated caption file if it exists
                                txt_path = file_path.with_suffix('.txt')
                                if txt_path.exists():
                                    shutil.copy2(txt_path, target_path.with_suffix('.txt'))
                                    logger.info(f"Copied caption for {file_path.name}")
                                    
                                video_count += 1
                                
                            # Process images
                            elif is_image_file(file_path):
                                target_path = STAGING_PATH / f"{file_path.stem}.{NORMALIZE_IMAGES_TO}"
                                
                                counter = 1
                                while target_path.exists():
                                    target_path = STAGING_PATH / f"{file_path.stem}___{counter}.{NORMALIZE_IMAGES_TO}"
                                    counter += 1
                                    
                                if normalize_image(file_path, target_path):
                                    logger.info(f"Processed image from dataset: {file_path.name} -> {target_path.name}")
                                    
                                    # Copy caption if available
                                    txt_path = file_path.with_suffix('.txt')
                                    if txt_path.exists():
                                        caption = txt_path.read_text()
                                        caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
                                        target_path.with_suffix('.txt').write_text(caption)
                                        logger.info(f"Processed caption for {file_path.name}")
                                    
                                    image_count += 1
                                
                            # Process WebDataset files
                            elif file.lower().endswith(".tar"):
                                # Process the WebDataset archive
                                try:
                                    logger.info(f"Processing WebDataset file from dataset: {file}")
                                    vid_count, img_count = webdataset_handler.process_webdataset_shard(
                                        file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
                                    )
                                    tar_count += 1
                                    video_count += vid_count
                                    image_count += img_count
                                    logger.info(f"Extracted {vid_count} videos and {img_count} images from {file}")
                                except Exception as e:
                                    logger.error(f"Error processing WebDataset file {file_path}: {str(e)}", exc_info=True)
                            
                            processed += 1
                                    
                # Run the processing asynchronously
                await process_files()
                
                # Update progress to complete
                if progress_callback:
                    progress_callback(100, desc="Processing complete", total=100)
                
                # Generate final status message
                parts = []
                if video_count > 0:
                    parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
                if image_count > 0:
                    parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
                if tar_count > 0:
                    parts.append(f"{tar_count} WebDataset archive{'s' if tar_count != 1 else ''}")
                
                if parts:
                    status = f"Successfully imported {', '.join(parts)} from dataset {dataset_id}"
                    loading_msg = f"{loading_msg}\n\n✅ Success! {status}"
                    logger.info(status)
                else:
                    status = f"No supported media files found in dataset {dataset_id}"
                    loading_msg = f"{loading_msg}\n\n⚠️ {status}"
                    logger.warning(status)
                
                gr.Info(status)
                return loading_msg, status
                
        except Exception as e:
            error_msg = f"Error downloading dataset {dataset_id}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            return f"Error: {error_msg}", error_msg
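

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): requires network access, and the
    # gr.Info/gr.Warning toasts are no-ops outside a running Gradio app.
    import sys

    logging.basicConfig(level=logging.INFO)
    browser = HubDatasetBrowser(HfApi())

    for dataset_id, title, downloads in browser.search_datasets("video")[:5]:
        print(f"{downloads:>8}  {dataset_id}  ({title})")

    # Pass a dataset ID as the first CLI argument to import its media into the
    # VMS staging folders (the configured paths must exist and be writable).
    if len(sys.argv) > 1:
        print(asyncio.run(browser.download_dataset(sys.argv[1], enable_splitting=True)))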