jbilcke-hf HF staff committed on
Commit
76eb17f
·
1 Parent(s): 222f539

working on adding WebDataset support

Browse files
vms/services/importer.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import shutil
3
  import zipfile
 
4
  import tempfile
5
  import gradio as gr
6
  from pathlib import Path
@@ -8,17 +9,18 @@ from typing import List, Dict, Optional, Tuple
8
  from pytubefix import YouTube
9
  import logging
10
 
11
- from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, TRAINING_PATH, DEFAULT_PROMPT_PREFIX
12
  from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
  class ImportService:
17
  def process_uploaded_files(self, file_paths: List[str]) -> str:
18
- """Process uploaded file (ZIP, MP4, or image)
19
 
20
  Args:
21
- file_paths: File paths to the ploaded files from Gradio
22
 
23
  Returns:
24
  Status message string
@@ -34,6 +36,8 @@ class ImportService:
34
 
35
  if file_ext == '.zip':
36
  return self.process_zip_file(file_path)
 
 
37
  elif file_ext == '.mp4' or file_ext == '.webm':
38
  return self.process_mp4_file(file_path, original_name)
39
  elif is_image_file(file_path):
@@ -86,7 +90,7 @@ class ImportService:
86
  raise gr.Error(f"Error processing image file: {str(e)}")
87
 
88
  def process_zip_file(self, file_path: Path) -> str:
89
- """Process uploaded ZIP file containing media files
90
 
91
  Args:
92
  file_path: Path to the uploaded ZIP file
@@ -97,6 +101,7 @@ class ImportService:
97
  try:
98
  video_count = 0
99
  image_count = 0
 
100
 
101
  # Create temporary directory
102
  with tempfile.TemporaryDirectory() as temp_dir:
@@ -115,7 +120,16 @@ class ImportService:
115
  file_path = Path(root) / file
116
 
117
  try:
118
- if is_video_file(file_path):
 
 
 
 
 
 
 
 
 
119
  # Copy video to videos_to_split
120
  target_path = VIDEOS_TO_SPLIT_PATH / file_path.name
121
  counter = 1
@@ -137,11 +151,13 @@ class ImportService:
137
 
138
  # Copy associated caption file if it exists
139
  txt_path = file_path.with_suffix('.txt')
140
- if txt_path.exists():
141
  if is_video_file(file_path):
142
  shutil.copy2(txt_path, target_path.with_suffix('.txt'))
143
  elif is_image_file(file_path):
144
- shutil.copy2(txt_path, target_path.with_suffix('.txt'))
 
 
145
 
146
  except Exception as e:
147
  logger.error(f"Error processing {file_path.name}: {str(e)}")
@@ -149,21 +165,54 @@ class ImportService:
149
 
150
  # Generate status message
151
  parts = []
 
 
152
  if video_count > 0:
153
- parts.append(f"{video_count} videos")
154
  if image_count > 0:
155
- parts.append(f"{image_count} images")
156
 
157
  if not parts:
158
  return "No supported media files found in ZIP"
159
 
160
- status = f"Successfully stored {' and '.join(parts)}"
161
  gr.Info(status)
162
  return status
163
 
164
  except Exception as e:
165
  raise gr.Error(f"Error processing ZIP: {str(e)}")
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  def process_mp4_file(self, file_path: Path, original_name: str) -> str:
168
  """Process a single video file
169
 
 
1
  import os
2
  import shutil
3
  import zipfile
4
+ import tarfile
5
  import tempfile
6
  import gradio as gr
7
  from pathlib import Path
 
9
  from pytubefix import YouTube
10
  import logging
11
 
12
+ from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
13
  from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption
14
+ from ..utils import webdataset_handler
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
  class ImportService:
19
  def process_uploaded_files(self, file_paths: List[str]) -> str:
20
+ """Process uploaded file (ZIP, TAR, MP4, or image)
21
 
22
  Args:
23
+ file_paths: File paths to the uploaded files from Gradio
24
 
25
  Returns:
26
  Status message string
 
36
 
37
  if file_ext == '.zip':
38
  return self.process_zip_file(file_path)
39
+ elif file_ext == '.tar':
40
+ return self.process_tar_file(file_path)
41
  elif file_ext == '.mp4' or file_ext == '.webm':
42
  return self.process_mp4_file(file_path, original_name)
43
  elif is_image_file(file_path):
 
90
  raise gr.Error(f"Error processing image file: {str(e)}")
91
 
92
  def process_zip_file(self, file_path: Path) -> str:
93
+ """Process uploaded ZIP file containing media files or WebDataset tar files
94
 
95
  Args:
96
  file_path: Path to the uploaded ZIP file
 
101
  try:
102
  video_count = 0
103
  image_count = 0
104
+ tar_count = 0
105
 
106
  # Create temporary directory
107
  with tempfile.TemporaryDirectory() as temp_dir:
 
120
  file_path = Path(root) / file
121
 
122
  try:
123
+ # Check if it's a WebDataset tar file
124
+ if file.lower().endswith('.tar'):
125
+ # Process WebDataset shard
126
+ vid_count, img_count = webdataset_handler.process_webdataset_shard(
127
+ file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
128
+ )
129
+ video_count += vid_count
130
+ image_count += img_count
131
+ tar_count += 1
132
+ elif is_video_file(file_path):
133
  # Copy video to videos_to_split
134
  target_path = VIDEOS_TO_SPLIT_PATH / file_path.name
135
  counter = 1
 
151
 
152
  # Copy associated caption file if it exists
153
  txt_path = file_path.with_suffix('.txt')
154
+ if txt_path.exists() and not file.lower().endswith('.tar'):
155
  if is_video_file(file_path):
156
  shutil.copy2(txt_path, target_path.with_suffix('.txt'))
157
  elif is_image_file(file_path):
158
+ caption = txt_path.read_text()
159
+ caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
160
+ target_path.with_suffix('.txt').write_text(caption)
161
 
162
  except Exception as e:
163
  logger.error(f"Error processing {file_path.name}: {str(e)}")
 
165
 
166
  # Generate status message
167
  parts = []
168
+ if tar_count > 0:
169
+ parts.append(f"{tar_count} WebDataset shard{'s' if tar_count != 1 else ''}")
170
  if video_count > 0:
171
+ parts.append(f"{video_count} video{'s' if video_count != 1 else ''}")
172
  if image_count > 0:
173
+ parts.append(f"{image_count} image{'s' if image_count != 1 else ''}")
174
 
175
  if not parts:
176
  return "No supported media files found in ZIP"
177
 
178
+ status = f"Successfully stored {', '.join(parts)}"
179
  gr.Info(status)
180
  return status
181
 
182
  except Exception as e:
183
  raise gr.Error(f"Error processing ZIP: {str(e)}")
184
 
185
def process_tar_file(self, file_path: Path) -> str:
    """Process an uploaded WebDataset shard (tar file).

    Delegates extraction to ``webdataset_handler.process_webdataset_shard``,
    sending videos to ``VIDEOS_TO_SPLIT_PATH`` and images/captions to
    ``STAGING_PATH``, then builds a human-readable summary.

    Args:
        file_path: Path to the uploaded tar file

    Returns:
        Status message string

    Raises:
        gr.Error: If the shard could not be processed. The original
            exception is chained as ``__cause__`` so the traceback is kept.
    """
    # Keep the try body limited to the call that can actually fail.
    try:
        video_count, image_count = webdataset_handler.process_webdataset_shard(
            file_path, VIDEOS_TO_SPLIT_PATH, STAGING_PATH
        )
    except Exception as e:
        raise gr.Error(f"Error processing WebDataset tar file: {str(e)}") from e

    # Generate status message (only mention media kinds that were found)
    counts = (("video", video_count), ("image", image_count))
    parts = [
        f"{count} {label}{'s' if count != 1 else ''}"
        for label, count in counts
        if count > 0
    ]

    if not parts:
        return "No supported media files found in WebDataset"

    status = f"Successfully extracted {' and '.join(parts)} from WebDataset"
    gr.Info(status)
    return status
215
+
216
  def process_mp4_file(self, file_path: Path, original_name: str) -> str:
217
  """Process a single video file
218
 
vms/tabs/import_tab.py CHANGED
@@ -47,16 +47,17 @@ class ImportTab(BaseTab):
47
  with gr.Column(scale=3):
48
  with gr.Row():
49
  with gr.Column():
50
- gr.Markdown("## Import video files")
51
  gr.Markdown("You can upload either:")
52
  gr.Markdown("- A single MP4 video file")
53
- gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
54
- gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
 
55
 
56
  with gr.Row():
57
  self.components["files"] = gr.Files(
58
- label="Upload Images, Videos or ZIP",
59
- file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"],
60
  type="filepath"
61
  )
62
 
 
47
  with gr.Column(scale=3):
48
  with gr.Row():
49
  with gr.Column():
50
+ gr.Markdown("## Import files")
51
  gr.Markdown("You can upload either:")
52
  gr.Markdown("- A single MP4 video file")
53
+ gr.Markdown("- A ZIP archive containing multiple videos/images and optional caption files")
54
+ gr.Markdown("- A WebDataset shard (.tar file)")
55
+ gr.Markdown("- A ZIP archive containing WebDataset shards (.tar files)")
56
 
57
  with gr.Row():
58
  self.components["files"] = gr.Files(
59
+ label="Upload Images, Videos, ZIP or WebDataset",
60
+ file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip", ".tar"],
61
  type="filepath"
62
  )
63
 
vms/utils/__init__.py CHANGED
@@ -6,6 +6,8 @@ from .image_preprocessing import normalize_image
6
  from .video_preprocessing import remove_black_bars
7
  from .finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
8
 
 
 
9
  __all__ = [
10
  'validate_model_repo',
11
  'make_archive',
@@ -30,4 +32,6 @@ __all__ = [
30
 
31
  'prepare_finetrainers_dataset',
32
  'copy_files_to_training_dir',
 
 
33
  ]
 
6
  from .video_preprocessing import remove_black_bars
7
  from .finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
8
 
9
+ from . import webdataset_handler
10
+
11
  __all__ = [
12
  'validate_model_repo',
13
  'make_archive',
 
32
 
33
  'prepare_finetrainers_dataset',
34
  'copy_files_to_training_dir',
35
+
36
+ 'webdataset_handler'
37
  ]
vms/utils/webdataset_handler.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebDataset format handling for Video Model Studio
3
+ """
4
+
5
+ import os
6
+ import tarfile
7
+ import tempfile
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import List, Dict, Tuple, Optional
11
+
12
+ from ..utils import is_image_file, is_video_file, extract_scene_info
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
def is_webdataset_file(file_path: Path) -> bool:
    """Check if file is a WebDataset tar file

    The check is purely name-based: the file is not opened or validated.

    Args:
        file_path: Path to check. Plain strings and other ``os.PathLike``
            objects are accepted as well as ``Path`` instances.

    Returns:
        bool: True if the final extension is ``.tar`` (case-insensitive).
            Compressed archives such as ``.tar.gz`` return False.
    """
    # Normalise to Path so callers holding a raw filename string can use this
    # helper instead of re-implementing the suffix test.
    return Path(file_path).suffix.lower() == '.tar'
26
+
27
# Extensions recognised inside a shard. Matching is case-insensitive.
_WDS_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic')
_WDS_VIDEO_EXTENSIONS = ('.mp4', '.webm')
# Ordered by preference: a plain-text caption beats raw metadata files.
_WDS_CAPTION_EXTENSIONS = ('.txt', '.caption', '.json', '.cls')


def _unique_target(target_dir: Path, stem: str, ext: str) -> Path:
    """Return a path in target_dir that collides with no media file or .txt sidecar."""
    candidate = target_dir / f"{stem}{ext}"
    counter = 1
    # Also check the caption sidecar so that an existing "<stem>.txt" belonging
    # to another media file is never overwritten or silently reused.
    while candidate.exists() or candidate.with_suffix('.txt').exists():
        candidate = target_dir / f"{stem}___{counter}{ext}"
        counter += 1
    return candidate


def process_webdataset_shard(
    tar_path: Path,
    videos_output_dir: Path,
    staging_output_dir: Path
) -> Tuple[int, int]:
    """Process a WebDataset shard (tar file) extracting video/image and caption pairs

    Members are grouped into samples by their key: the directory plus the
    filename up to the first dot. For every sample that contains a supported
    media file, the media is written to the matching output directory and,
    when a caption-like member exists, its text is saved next to it as
    ``<name>.txt``.

    Args:
        tar_path: Path to the WebDataset tar file
        videos_output_dir: Directory to store videos for splitting
        staging_output_dir: Directory to store images and captions

    Returns:
        Tuple of (video_count, image_count)

    Raises:
        Exception: Any error raised while reading the archive or writing
            files is logged and re-raised unchanged.
    """
    # Local import: this module's top-level imports do not include shutil.
    import shutil

    videos_output_dir = Path(videos_output_dir)
    staging_output_dir = Path(staging_output_dir)
    caption_rank = {ext: rank for rank, ext in enumerate(_WDS_CAPTION_EXTENSIONS)}

    video_count = 0
    image_count = 0

    try:
        # The destination folders may not exist yet on a fresh install.
        videos_output_dir.mkdir(parents=True, exist_ok=True)
        staging_output_dir.mkdir(parents=True, exist_ok=True)

        # A single open is enough: members are indexed first, then extracted.
        with tarfile.open(tar_path, 'r') as tar:
            # Group members by sample key
            grouped_files = {}
            for member in tar.getmembers():
                # Only regular files carry data; extractfile() returns None
                # for devices/FIFOs and directories have nothing to extract.
                if not member.isfile():
                    continue

                member_path = Path(member.name)
                file_name = member_path.name

                # Skip hidden files
                if file_name.startswith('.'):
                    continue

                # For WebDataset, the key is everything up to the first dot
                prefix, dot, remainder = file_name.partition('.')
                if not dot:
                    # No extension, skip
                    continue
                extension = '.' + remainder

                # Include directory in the key to keep samples grouped correctly
                if member_path.parent != Path('.'):
                    full_prefix = str(member_path.parent / prefix)
                else:
                    full_prefix = prefix

                grouped_files.setdefault(full_prefix, []).append((member, extension))

            # Extract each sample
            for prefix, members in grouped_files.items():
                # Drop any directory component so nothing is written outside
                # the output directories.
                safe_prefix = Path(prefix).name

                media_file = None
                media_ext = None
                caption_file = None
                best_rank = None

                for member, ext in members:
                    lowered = ext.lower()
                    if lowered in _WDS_IMAGE_EXTENSIONS or lowered in _WDS_VIDEO_EXTENSIONS:
                        media_file = member
                        media_ext = ext
                    elif lowered in caption_rank:
                        # Pick the most caption-like member regardless of the
                        # order in which members appear in the archive.
                        rank = caption_rank[lowered]
                        if best_rank is None or rank <= best_rank:
                            caption_file = member
                            best_rank = rank

                if media_file is None:
                    continue

                is_video = media_ext.lower() in _WDS_VIDEO_EXTENSIONS
                target_dir = videos_output_dir if is_video else staging_output_dir
                target_path = _unique_target(target_dir, safe_prefix, media_ext)

                media_source = tar.extractfile(media_file)
                if media_source is None:
                    logger.warning(f"Skipping unreadable member {media_file.name} in {tar_path}")
                    continue

                # Stream the payload instead of loading whole videos in memory
                with media_source, open(target_path, 'wb') as out:
                    shutil.copyfileobj(media_source, out)

                if caption_file is not None:
                    caption_source = tar.extractfile(caption_file)
                    if caption_source is not None:
                        with caption_source:
                            caption_text = caption_source.read().decode('utf-8', errors='ignore')

                        # Caption is stored beside the media file with a .txt suffix
                        caption_path = target_path.with_suffix('.txt')
                        with open(caption_path, 'w', encoding='utf-8') as f:
                            f.write(caption_text)

                if is_video:
                    video_count += 1
                else:
                    image_count += 1

    except Exception as e:
        logger.error(f"Error processing WebDataset file {tar_path}: {e}")
        raise

    return video_count, image_count