Spaces:
Runtime error
Runtime error
# Copyright 2024 EPFL and Apple Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import copy | |
import io | |
import itertools | |
import os | |
import re | |
from functools import partial | |
from typing import Any, Callable, Dict, Iterable, List, Optional | |
import braceexpand | |
import numpy as np | |
import torch | |
import webdataset as wds | |
from PIL import Image | |
from torch.utils.data import IterableDataset | |
from torch.utils.data._utils.collate import default_collate | |
from torchvision import transforms | |
from webdataset.filters import pipelinefilter, reraise_exception | |
from webdataset.handlers import warn_and_continue | |
try: | |
# Optionally load huggingface datasets | |
from datasets import load_dataset | |
from datasets.distributed import split_dataset_by_node | |
except ImportError: | |
print("Huggingface datasets not installed. Please install with `pip install datasets`.") | |
from fourm.data.masking import TransferMasking, UnifiedMasking | |
from fourm.data.modality_transforms import (CropSettingsTransform, IdentityTransform, | |
MaskTransform, UnifiedDataTransform, | |
get_transform_key) | |
from fourm.data.multimodal_dataset_folder import MultiModalDatasetFolder | |
from fourm.utils.dist import get_rank, get_world_size | |
def build_fm_pretraining_dataset(
        data_path, all_domains, modality_info, modality_transforms,
        image_augmenter, text_tokenizer,
        input_tokens_range, target_tokens_range,
        sampling_weights=None):
    """Build the FourM pre-training dataset from a folder-based multi-modal dataset.

    This function should mainly be used for smaller datasets (e.g. validation sets).
    Large training sets should be loaded in webdataset format with
    build_wds_fm_pretraining_dataloader instead.

    Args:
        data_path: Path to the dataset.
        all_domains: List of all modalities to be used.
        modality_info: Dictionary containing information about the modalities.
        modality_transforms: Dictionary containing the transforms for each modality.
            The caller's dictionary is not modified; it is deep-copied before
            the "crop_settings" entry is added.
        image_augmenter: Image augmenter.
        text_tokenizer: Text tokenizer (for sequence modalities).
        input_tokens_range: Range of the input token budget.
        target_tokens_range: Range of the target token budget.
        sampling_weights: Sampling weights for the mixture of Dirichlet distributions.

    Returns:
        FourM pre-training dataset (a MultiModalDatasetFolder).
    """
    # NOTE: the sample-level transform is assembled from the transforms as passed in,
    # i.e. before the optional "crop_settings" entry is added further down.
    data_transform = UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter)
    masking = UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
                             input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
                             sampling_weights=sampling_weights)
    transform = transforms.Compose([data_transform, masking])

    # Keep only the modalities that can be loaded without running a tokenizer
    modalities_without_vq = []
    for mod in all_domains:
        if not modality_info[mod].get("requires_tokenizer", False):
            modalities_without_vq.append(mod)

    # Pre-tokenized modalities come with pre-computed crop settings, so load those as well
    uses_pretokenized = any(modality_info[domain].get("pretokenized", False) for domain in all_domains)
    if uses_pretokenized:
        modalities_without_vq.append("crop_settings")
        modality_transforms = copy.deepcopy(modality_transforms)
        modality_transforms["crop_settings"] = CropSettingsTransform()

    # Only modalities that declare an explicit, non-None path get an entry
    modality_paths = {}
    for mod in modality_info:
        if modality_info[mod].get('path', None) is not None:
            modality_paths[mod] = modality_info[mod]['path']

    return MultiModalDatasetFolder(
        root=data_path,
        modalities=modalities_without_vq,
        modality_paths=modality_paths,
        modality_transforms=modality_transforms,
        transform=transform,
    )
def build_fm_transfer_dataset(
        data_path, modality_info, transform, modality_transforms, all_domains,
        load_mask_valid: bool = False, max_samples: Optional[int] = None,
        pre_shuffle: bool = False, cache: bool = False):
    """Build the FourM transfer dataset from a folder-based multi-modal dataset.

    Args:
        data_path: Path to the dataset.
        modality_info: Dictionary containing information about the modalities.
        transform: Transform to be applied to the dataset.
        modality_transforms: Dictionary containing the transforms for each modality.
            The caller's dictionary is not modified; it is deep-copied before any
            extra entry ("crop_settings", "mask_valid") is added.
        all_domains: List of all modalities to be used.
        load_mask_valid: Whether to load the mask_valid "modality".
        max_samples: Maximum number of samples to load.
        pre_shuffle: Whether to shuffle the dataset before loading.
        cache: Whether to cache the dataset in memory.

    Returns:
        FourM transfer dataset (a MultiModalDatasetFolder).
    """
    # Keep only the modalities that can be loaded without running a tokenizer
    modalities_without_vq = []
    for mod in all_domains:
        if not modality_info[mod].get("requires_tokenizer", False):
            modalities_without_vq.append(mod)

    # Pre-tokenized modalities come with pre-computed crop settings, so load those as well
    uses_pretokenized = any(modality_info[domain].get("pretokenized", False) for domain in all_domains)
    if uses_pretokenized:
        modalities_without_vq.append("crop_settings")
        modality_transforms = copy.deepcopy(modality_transforms)
        modality_transforms["crop_settings"] = CropSettingsTransform()

    # Optionally load the validity mask as an additional pseudo-modality
    if load_mask_valid:
        modalities_without_vq.append("mask_valid")
        modality_transforms = copy.deepcopy(modality_transforms)
        modality_transforms["mask_valid"] = MaskTransform()

    # Only modalities that declare an explicit, non-None path get an entry
    modality_paths = {}
    for mod in modality_info:
        if modality_info[mod].get('path', None) is not None:
            modality_paths[mod] = modality_info[mod]['path']

    return MultiModalDatasetFolder(
        root=data_path,
        modalities=modalities_without_vq,
        modality_paths=modality_paths,
        modality_transforms=modality_transforms,
        transform=transform,
        max_samples=max_samples,
        pre_shuffle=pre_shuffle,
        cache=cache,
    )
### Webdatasets (wds) functions | |
def _keyless_map(data, f, handler=reraise_exception):
    """Map samples through `f` without adding `__key__` to the results.

    Args:
        data: Iterable of samples.
        f: Function applied to every sample. A return value of None drops the sample.
        handler: Called with any exception raised by `f`. A truthy return value
            skips the offending sample, a falsy one stops the iteration.

    Yields:
        The non-None results of `f`, in input order.
    """
    for sample in data:
        try:
            mapped = f(sample)
        except Exception as exn:
            if not handler(exn):
                break
            continue
        if mapped is not None:
            yield mapped


# Pipeline-stage form of _keyless_map.
# NOTE: this rebinds the name `map`, so the builtin `map` is shadowed in the rest of this module.
map = pipelinefilter(_keyless_map)
def check_dots(s):
    """Return True if `s` contains the expected number of dots.

    Strings containing '.gz' are expected to have exactly two dots
    (e.g. "caption.json.gz"), all other strings exactly one (e.g. "rgb.jpg").
    """
    expected_dots = 2 if '.gz' in s else 1
    return s.count('.') == expected_dots
def remove_ext_with_gz(s):
    """Strip the file extension from `s`, looking through a trailing '.gz'.

    Examples: "rgb.jpg" -> "rgb", "caption.json.gz" -> "caption".

    Args:
        s: Key or file name, optionally ending in '.gz'.

    Returns:
        `s` without its trailing '.gz' (if any) and without the last remaining extension.
    """
    gz_suffix = '.gz'
    if s.endswith(gz_suffix):
        # Drop only the trailing suffix. str.replace would also remove any
        # earlier ".gz" occurrence inside the name and corrupt it.
        s = s[:-len(gz_suffix)]
    return os.path.splitext(s)[0]
def wds_decoder(key, value):
    """Decode the raw bytes of a webdataset entry based on its key.

    Args:
        key: Entry key, either a bare extension (e.g. "png") or a name ending in one (e.g. "rgb.png").
        value: Raw bytes of the entry.

    Returns:
        A PIL image for png / jpg / jpeg / jpx entries, the loaded numpy content for
        keys ending in "npy", an int for keys containing 'output', and None otherwise.
    """
    image_extensions = ("png", "jpg", "jpeg", "jpx")
    image_suffixes = (".png", ".jpg", ".jpeg", ".jpx")

    if key in image_extensions or key.endswith(image_suffixes):
        return Image.open(io.BytesIO(value))

    if key.endswith("npy"):
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only use on trusted shards.
        return np.load(io.BytesIO(value), allow_pickle=True)

    if 'output' in key:
        return int(value)

    # If not an image, use the basic handlers (.txt, .json, .pickle, .npz, ...)
    # See https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
    return None
def repeat_fn(src, n_repeats=5):
    """
    Yield every sample of `src` n_repeats times in a row.
    E.g. A B C ... with n_repeats=3 becomes A A A B B B C C C ...
    The same object is yielded each time (no copies are made).
    Depending on the downstream application, a shuffle should be added after this.
    """
    for sample in src:
        yield from itertools.repeat(sample, n_repeats)
def remove_extensions(sample):
    """
    Strip the "file extension" from every key of a webdataset sample.

    In webdatasets, the type of a modality is identified by an extension of the form
    f"{modality_name}.{modality_extension}", e.g. "rgb.jpg" or "caption.json".
    This returns a new dictionary of {f"{modality_name}": modality}; values are untouched.
    """
    stripped = {}
    for name_with_ext, modality in sample.items():
        stripped[remove_ext_with_gz(name_with_ext)] = modality
    return stripped
def filter_metadata(sample, metadata=['__key__', '__url__', 'file_name', 'class_name', 'class_idx']):
    """ Returns a copy of sample without the non-modality entries listed in metadata (used when loading tar files with webdatasets). """
    kept = {}
    for name, value in sample.items():
        if name in metadata:
            continue
        kept[name] = value
    return kept
def apply_modality_transforms(sample, modality_transforms):
    """ Applies a dictionary of modality-specific transforms to a dictionary of modalities; entries without a transform pass through unchanged. """
    transformed = {}
    for name, value in sample.items():
        # NOTE(review): membership is tested with the raw key, but the transform is looked up
        # with get_transform_key(name) -- confirm both keys always agree.
        if name in modality_transforms:
            value = modality_transforms[get_transform_key(name)](value)
        transformed[name] = value
    return transformed
def tok_to_int64(sample):
    """
    Cast every entry whose key contains 'tok_' to int64 and leave all other entries as they are.
    (Pre-computed tokens are saved as int16, but we need them as int64 instead.)
    """
    converted = {}
    for name, value in sample.items():
        if 'tok_' in name:
            value = value.astype('int64')
        converted[name] = value
    return converted
def rename_modalities(sample, modality_paths):
    """
    Builds a new sample keyed by the names in modality_paths.

    modality_paths maps each output name to the key under which the modality was loaded.
    Entries of sample that are not referenced by modality_paths are dropped, and a
    referenced key that is missing from sample raises a KeyError.
    """
    renamed = {}
    for out_path, loaded_path in modality_paths.items():
        renamed[out_path] = sample[loaded_path]
    return renamed
def extract_modality_names(s):
    """Return the comma-separated entries of the first '{...}' group in `s`, or [] if there is none."""
    # Matches anything enclosed in '{' and '}'; only the first such group is considered
    match = re.search(r'\{([^}]*)\}', s)
    if match is None:
        return []
    return match.group(1).split(',')
def identity(sample):
    """ No-op pipeline stage: returns the very same sample object it was given. """
    return sample
def multi_tarfile_samples(src_iter: Iterable[Dict[str, Any]],
                          modality_name_map: Dict[str, str] = None,
                          handler: Callable[[Exception], bool] = warn_and_continue):
    """Yield samples merged across per-modality tar files.

    Webdataset does not support splitting up shards by modality, so this is done manually here.
    Usually, all modalities would be saved in the same tar file, e.g. shard_root_train/{00000..12345}.tar,
    where each shard contains 1000 samples and each sample contains all modalities.
    That is not flexible when adding new modalities, so each modality is instead saved in a separate tar file,
    e.g. shard_root_train_rgb/{00000..12345}.tar, shard_root_train_caption/{00000..12345}.tar, etc., where each
    shard again contains 1000 samples, but each sample contains only one modality.
    All samples in all shards have to be aligned.

    The shard URLs use brace expansion with square braces to specify multiple tar files per modality.
    E.g. shard_root_train_[rgb,caption]/00123.tar is expanded to shard_root_train_rgb/00123.tar and
    shard_root_train_caption/00123.tar, and the samples of these two tar files are combined into a single sample.
    Only the first bracket group of a URL is used to determine the modality names.

    Args:
        src_iter: Iterator over shards that *already brace expanded the shard numbers*,
            e.g. {'url': 'shard_root_train_[rgb,caption]/00000.tar'}, {'url': 'shard_root_train_[rgb,caption]/00001.tar'}, ...
            This also works when no square braces for multiple modalities are used, e.g. {'url': 'shard_root_train/00000.tar'}, ...
            It can be a drop-in replacement for wds.tarfile_samples.
        modality_name_map: Optional dictionary specifying a mapping from modality folder names to arbitrary other names.
        handler: Function that handles exceptions. If it returns True, the shard is skipped. If it returns False, the function exits.

    Yields:
        Dictionary of aligned samples from all modalities, with '__key__' taken from the first
        tar file and '__url__' set to the original (unexpanded) shard URL.
    """
    for src in src_iter:
        # Multi tar file URLs use square braces; turn them into curly braces for brace expansion
        braced_url = src['url'].translate(str.maketrans('[]', '{}'))
        modality_names = extract_modality_names(braced_url)
        num_modalities = len(modality_names)

        if num_modalities > 1:
            # Multiple modalities are specified, e.g. shard_dir/[foo,bar]/shard00000.tar
            tar_urls = list(braceexpand.braceexpand(braced_url))
        elif num_modalities == 1:
            # Brace expand doesn't work with a single entry, e.g. shard_dir/[foo]/shard00000.tar
            tar_urls = [braced_url.replace("{", "").replace("}", "")]
        else:
            # Multi-modal braceexpand is not used, e.g. shard_dir/shard00000.tar
            modality_names = [None]
            tar_urls = [braced_url]

        # One sample iterator per modality shard
        tar_iters = [wds.tarfile_samples([{'url': tar_url}]) for tar_url in tar_urls]

        try:
            # Walk all modality shards in lockstep (zip stops at the shortest one)
            for multi_tar_files in zip(*tar_iters):
                merged_dict = {
                    '__key__': multi_tar_files[0]['__key__'],
                    '__url__': src['url'],
                }

                for modality_name, modality_dict in zip(modality_names, multi_tar_files):
                    _key = modality_dict.pop('__key__')
                    # The per-tar URL is dropped; it must be removed before counting the remaining entries
                    modality_dict.pop('__url__')

                    if _key != merged_dict['__key__']:
                        raise ValueError(f"Divergence detected! Trying to merge keys {_key} of {modality_name} and {merged_dict['__key__']} of merged_dict with modalities {merged_dict.keys()}.")

                    # Keys are kept unchanged when
                    # 1. the shard contains multiple modalities (they *have* to follow the idx.modality_id.ext convention), or
                    # 2. no modality folder was specified (modality name is None), or
                    # 3. the key passes check_dots, i.e. it already has the modality_id.ext format (idx. is already removed at this stage)
                    keep_all_keys = len(modality_dict) > 1 or modality_name is None

                    for k, v in modality_dict.items():
                        if keep_all_keys or check_dots(k):
                            merged_dict[k] = v
                            continue
                        # Otherwise prefix the bare extension with the (optionally remapped) modality name
                        if modality_name_map is None:
                            mapped_name = modality_name
                        else:
                            mapped_name = modality_name_map.get(modality_name, modality_name)
                        merged_dict[f'{mapped_name}.{k}'] = v

                yield merged_dict

        except Exception as e:
            print(e)
            print(f"Exception occurred while processing {src['url']}.")
            if not handler(e):
                break
            print('Skipping shard...')
def build_wds_fm_pretraining_dataloader(
        data_path, all_domains, modality_info, modality_transforms, image_augmenter,
        text_tokenizer, input_tokens_range, target_tokens_range,
        num_gpus, num_workers, batch_size, epoch_size, sampling_weights=None, modality_name_map=None,
        shuffle_buffer_load=1000, shuffle_buffer_repeat=5000, n_repeats=5):
    """Builds the WebDataset FourM pre-training dataloader based on the given arguments.

    Args:
        data_path: Path to the dataset.
        all_domains: List of all modalities to be used.
        modality_info: Dictionary containing information about the modalities.
        modality_transforms: Dictionary containing the transforms for each modality.
            The caller's dictionary is left untouched; extra entries are added to a copy.
        image_augmenter: Image augmenter.
        text_tokenizer: Text tokenizer (for sequence modalities).
        input_tokens_range: Range of the input token budget.
        target_tokens_range: Range of the target token budget.
        num_gpus: Number of GPUs.
        num_workers: Number of workers. 0 is allowed and counted as a single worker
            when computing the per-worker epoch length.
        batch_size: Batch size. If None, samples are not batched and the bare data pipeline is returned.
        epoch_size: Number of samples per "epoch" (here, a stretch of training without evaluation
            or checkpointing). If None, the iterator length is not pre-defined.
        sampling_weights: Sampling weights for the mixture of Dirichlet distributions.
        modality_name_map: Optional dictionary specifying a mapping from modality folder names to arbitrary other names.
        shuffle_buffer_load: Shuffle buffer size when loading samples from tar files (first shuffle).
        shuffle_buffer_repeat: Shuffle buffer size after repeating samples (second shuffle).
        n_repeats: Number of times to repeat each sample.

    Returns:
        FourM pre-training dataloader as a WebDataset DataLoader (wds.WebLoader) if batch_size is given,
        otherwise the wds.DataPipeline itself.
    """
    # Modalities without an explicit path are loaded under their own name
    modality_paths = {mod: modality_info[mod].get('path', None) or mod for mod in modality_info}

    # Shallow-copy first so that the entries added below never leak into the caller's dictionary
    modality_transforms = copy.copy(modality_transforms)

    # If we are using a pre-tokenized modality, we default to pre-computed crop settings
    if any(modality_info[domain].get("pretokenized", False) for domain in all_domains):
        modality_transforms = copy.deepcopy(modality_transforms)
        modality_transforms["crop_settings"] = CropSettingsTransform()
        modality_paths["crop_settings"] = "crop_settings"

    # Webdatasets always adds __key__ to the dictionary, so we add a transform that does nothing with it
    modality_transforms["__key__"] = IdentityTransform()

    transform = transforms.Compose([
        UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
        UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
                       input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
                       sampling_weights=sampling_weights)
    ])

    if batch_size is not None:
        batching_stage = wds.batched(batch_size, collation_fn=default_collate, partial=False)
    else:
        batching_stage = map(identity)

    datapipe = wds.DataPipeline(
        # Infinitely sample shards from the shard list with replacement. Each worker is seeded independently.
        wds.ResampledShards(data_path),
        partial(multi_tarfile_samples, modality_name_map=modality_name_map),  # Extract individual samples from single or multi-modal tar files
        wds.shuffle(shuffle_buffer_load),  # Shuffle with a buffer of given size
        wds.decode(wds_decoder),  # Decode from bytes to PIL images, numpy arrays, etc.
        wds.filters.compose(partial(repeat_fn, n_repeats=n_repeats)),  # Repeats each sample n times -> A A A B B B C C C ...
        wds.shuffle(shuffle_buffer_repeat),  # Shuffle again with a buffer of given size
        wds.map(remove_extensions),  # Remove "file extensions" from dictionary keys
        map(filter_metadata),  # Remove non-task keys
        map(tok_to_int64),  # Convert pre-computed tokens to int64
        map(partial(rename_modalities, modality_paths=modality_paths)),  # Rename modalities to their corresponding names in modality_paths
        map(transform),  # Apply data augmentation and masking
        batching_stage,  # Batching
    )

    if epoch_size is not None:
        batch_size_iter = batch_size if batch_size is not None else 1
        # num_workers=0 means loading happens in the main process, which acts as a single worker;
        # clamping also avoids a division by zero.
        effective_workers = max(num_workers, 1)
        datapipe = datapipe.with_epoch(epoch_size // (num_gpus * effective_workers * batch_size_iter))  # Pre-define iterator length

    if batch_size is None:
        return datapipe

    # Perform multi-threaded dataloading
    return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
def build_wds_divae_dataloader(
        data_path, modality_info, modality_transforms, image_augmenter,
        num_gpus, num_workers, batch_size, epoch_size, shuffle_buffer_load=1000,
        shuffle_buffer_repeat=5000, n_repeats=1):
    """Builds a WebDataset dataloader that applies the unified data transform only (no masking).

    Args:
        data_path: Path to the dataset.
        modality_info: Dictionary containing information about the modalities.
        modality_transforms: Dictionary containing the transforms for each modality.
            The caller's dictionary is left untouched; the "__key__" entry is added to a copy.
        image_augmenter: Image augmenter.
        num_gpus: Number of GPUs.
        num_workers: Number of workers. 0 is allowed and counted as a single worker
            when computing the per-worker epoch length.
        batch_size: Batch size. If None, samples are not batched and the bare data pipeline is returned.
        epoch_size: Number of samples per "epoch". If None, the iterator length is not pre-defined.
        shuffle_buffer_load: Shuffle buffer size when loading samples from tar files (first shuffle).
        shuffle_buffer_repeat: Shuffle buffer size after repeating samples (second shuffle).
        n_repeats: Number of times to repeat each sample.

    Returns:
        A wds.WebLoader if batch_size is given, otherwise the wds.DataPipeline itself.
    """
    # Modalities without an explicit path are loaded under their own name
    modality_paths = {mod: modality_info[mod].get('path', None) or mod for mod in modality_info}

    # Webdatasets always adds __key__ to the dictionary, so we add a transform that does nothing with it.
    # The entry is added to a shallow copy so the caller's dictionary is not modified.
    modality_transforms = copy.copy(modality_transforms)
    modality_transforms["__key__"] = IdentityTransform()

    transform = UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter)

    if batch_size is not None:
        batching_stage = wds.batched(batch_size, collation_fn=default_collate, partial=False)
    else:
        batching_stage = map(identity)

    datapipe = wds.DataPipeline(
        # Infinitely sample shards from the shard list with replacement. Each worker is seeded independently.
        wds.ResampledShards(data_path),
        multi_tarfile_samples,  # Extract individual samples from single or multi-modal tar files
        wds.shuffle(shuffle_buffer_load),  # Shuffle with a buffer of given size
        wds.decode(wds_decoder),  # Decode from bytes to PIL images, numpy arrays, etc.
        wds.filters.compose(partial(repeat_fn, n_repeats=n_repeats)),  # Repeats each sample n times -> A A A B B B C C C ...
        wds.shuffle(shuffle_buffer_repeat),  # Shuffle again with a buffer of given size
        map(remove_extensions),  # Remove "file extensions" from dictionary keys
        map(filter_metadata),  # Remove non-task keys
        map(tok_to_int64),  # Convert pre-computed tokens to int64
        map(partial(rename_modalities, modality_paths=modality_paths)),  # Rename modalities to their corresponding names in modality_paths
        map(transform),  # Apply data augmentation
        batching_stage,  # Batching
    )

    if epoch_size is not None:
        batch_size_iter = batch_size if batch_size is not None else 1
        # num_workers=0 means loading happens in the main process, which acts as a single worker;
        # clamping also avoids a division by zero.
        effective_workers = max(num_workers, 1)
        datapipe = datapipe.with_epoch(epoch_size // (num_gpus * effective_workers * batch_size_iter))  # Pre-define iterator length

    if batch_size is None:
        return datapipe

    # Perform multi-threaded dataloading
    return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
### Huggingface datasets functions | |
def text_to_caption(sample):
    """ Returns a new sample holding only the "text" entry, stored under the key "caption". """
    caption = sample['text']
    return {'caption': caption}
def build_huggingface_pretraining_dataloader(
        data_path, all_domains, modality_info, modality_transforms, image_augmenter,
        text_tokenizer, input_tokens_range, target_tokens_range,
        num_gpus, num_workers, batch_size, epoch_size, split,
        streaming=True, rename_text_to_caption=True, shuffle_buffer_load=10_000, shuffle_seed=0):
    """Builds a FourM pre-training dataloader on top of a Huggingface dataset.

    Requires the optional `datasets` package (see the guarded import at the top of this file).

    Args:
        data_path: Dataset path / name passed to `load_dataset`.
        all_domains: List of all modalities to be used; modality_info is restricted to these.
        modality_info: Dictionary containing information about the modalities.
        modality_transforms: Dictionary containing the transforms for each modality.
        image_augmenter: Image augmenter.
        text_tokenizer: Text tokenizer (for sequence modalities).
        input_tokens_range: Range of the input token budget.
        target_tokens_range: Range of the target token budget.
        num_gpus: Number of GPUs.
        num_workers: Number of workers. Capped at the number of dataset shards; a resulting value
            of 0 is counted as a single worker when computing the per-worker epoch length.
        batch_size: Batch size. If None, samples are not batched and the bare data pipeline is returned.
        epoch_size: Number of samples per "epoch". If None, the iterator length is not pre-defined.
        split: Dataset split passed to `load_dataset`.
        streaming: Whether to load the dataset in streaming mode.
        rename_text_to_caption: Whether to replace each sample by {'caption': sample['text']}.
        shuffle_buffer_load: Buffer size used to shuffle the dataset.
        shuffle_seed: Seed used to shuffle the dataset.

    Returns:
        A wds.WebLoader if batch_size is given, otherwise the wds.DataPipeline itself.
    """
    # Load huggingface dataset and split samples across workers. Shuffle samples in each worker
    dataset = load_dataset(data_path, split=split, streaming=streaming)
    dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())
    dataset = dataset.shuffle(seed=shuffle_seed, buffer_size=shuffle_buffer_load)

    # Only keep the info of the modalities that are actually used
    modality_info = {mod: modality_info[mod] for mod in modality_info if mod in all_domains}

    transform = transforms.Compose([
        UnifiedDataTransform(transforms_dict=modality_transforms, image_augmenter=image_augmenter),
        UnifiedMasking(modality_info=modality_info, text_tokenizer=text_tokenizer,
                       input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range)
    ])

    rename_stage = map(text_to_caption) if rename_text_to_caption else map(identity)
    if batch_size is not None:
        batching_stage = wds.batched(batch_size, collation_fn=default_collate, partial=False)
    else:
        batching_stage = map(identity)

    datapipe = wds.DataPipeline(
        dataset,
        rename_stage,  # Rename "text" to "caption"
        map(filter_metadata),  # Remove non-task keys
        map(transform),  # Apply data augmentation and masking
        batching_stage,  # Batching
    )

    datapipe.n_shards = dataset.n_shards
    # Never use more workers than there are shards
    num_workers = min(num_workers, dataset.n_shards)

    if epoch_size is not None:
        batch_size_iter = batch_size if batch_size is not None else 1
        # num_workers=0 means loading happens in the main process, which acts as a single worker;
        # clamping also avoids a division by zero.
        effective_workers = max(num_workers, 1)
        datapipe = datapipe.with_epoch(epoch_size // (num_gpus * effective_workers * batch_size_iter))  # Pre-define iterator length

    if batch_size is None:
        return datapipe

    # Perform multi-threaded dataloading
    return wds.WebLoader(datapipe, num_workers=num_workers, batch_size=None)
### Multi-dataset loading utils | |
def make_empty_mod_dict(modality_info):
    """Create a placeholder entry (zero tensor plus masks) for every modality.

    Args:
        modality_info: Dictionary mapping modality names to their info dictionaries.
            Every info dictionary must contain 'max_tokens'; unless it is image-like
            (has 'num_channels' and 'input_size') or named 't5_caption', it must also contain 'type'.

    Returns:
        Dictionary mapping each modality name to a dictionary with the keys
        'tensor', 'input_mask', 'target_mask' and 'decoder_attention_mask'.
    """
    sequence_types = ('seq', 'seq_emb', 'seq_token')
    empty_mod_dicts = {}

    for mod_name, mod_info in modality_info.items():
        if 'num_channels' in mod_info and 'input_size' in mod_info:
            # Image-like modalities: square float tensor of shape (channels, size, size)
            max_tokens = mod_info['max_tokens']
            side = mod_info['input_size']
            tensor = torch.zeros((mod_info['num_channels'], side, side), dtype=torch.float32)
        elif mod_name == 't5_caption':
            # T5 embedding: one embedding vector per token
            max_tokens = mod_info['max_tokens']
            orig_emb_dim = mod_info['encoder_embedding']().orig_emb_dim
            tensor = torch.zeros((max_tokens, orig_emb_dim), dtype=torch.float32)
        else:
            if mod_info['type'] in sequence_types:
                # All other discrete sequence modalities use a length of (max_tokens + 1) * 2
                # NOTE(review): presumably room for input + target sequences incl. special tokens -- confirm
                max_tokens = (mod_info['max_tokens'] + 1) * 2
            else:
                max_tokens = mod_info['max_tokens']
            tensor = torch.zeros(max_tokens, dtype=torch.int32)

        empty_mod_dicts[mod_name] = {
            'tensor': tensor,
            # Input and target masks are all ones
            'input_mask': torch.ones(max_tokens, dtype=torch.bool),
            'target_mask': torch.ones(max_tokens, dtype=torch.bool),
            # Decoder attention mask is all zeros
            'decoder_attention_mask': torch.zeros(max_tokens, dtype=torch.int32),
        }

    return empty_mod_dicts
class MixtureDataset(IterableDataset):
    """Infinite iterable dataset that samples from several data sources at random.

    For every yielded sample, one source is drawn with probability proportional to its weight.
    The sample is laid over a fresh placeholder dictionary from make_empty_mod_dict, so
    modalities missing from the sample keep their placeholder entry.
    """

    def __init__(self, data_iters, weights, modality_info):
        """
        Args:
            data_iters: Iterables to sample from (one per dataset).
            weights: Unnormalized sampling weight of each entry in data_iters.
            modality_info: Dictionary containing information about the modalities.
        """
        self.orig_data_iters = data_iters
        # Create the initial iterators
        initial_iters = []
        for source in data_iters:
            initial_iters.append(iter(source))
        self.data_iters = initial_iters
        # Normalize the weights into sampling probabilities
        total_weight = sum(weights)
        self.sampling_probs = np.array(weights) / total_weight
        self.modality_info = modality_info

    def reset_iterator(self, idx):
        """ Restart the idx-th iterator from its original iterable (used once it is exhausted). """
        fresh_iter = iter(self.orig_data_iters[idx])
        self.data_iters[idx] = fresh_iter

    def __iter__(self):
        while True:
            # Pick the dataset to draw the next sample from
            dataset_idx = np.random.choice(len(self.sampling_probs), p=self.sampling_probs)
            try:
                data = next(self.data_iters[dataset_idx])
            except StopIteration:
                # The chosen iterator is exhausted: restart it and draw again
                self.reset_iterator(dataset_idx)
                data = next(self.data_iters[dataset_idx])

            # Start from placeholders for all modalities, then overwrite with the loaded ones
            mod_dict = make_empty_mod_dict(self.modality_info)
            mod_dict.update(data)
            yield mod_dict
def build_mixture_dataloader(data_iters, weights, modality_info, batch_size, num_workers, epoch_size, num_gpus):
    """Builds a dataloader that mixes several data sources according to the given weights.

    Args:
        data_iters: Iterables to sample from (one per dataset).
        weights: Unnormalized sampling weight of each entry in data_iters.
        modality_info: Dictionary containing information about the modalities.
        batch_size: Batch size.
        num_workers: Number of workers. 0 is allowed and counted as a single worker
            when computing the per-worker epoch length.
        epoch_size: Number of samples per "epoch".
        num_gpus: Number of GPUs.

    Returns:
        Mixture dataloader as a WebDataset DataLoader (wds.WebLoader).
    """
    # num_workers=0 means loading happens in the main process, which acts as a single worker;
    # clamping also avoids a division by zero.
    effective_workers = max(num_workers, 1)
    batches_per_worker = epoch_size // (num_gpus * effective_workers * batch_size)

    mixture_pipe = wds.DataPipeline(
        MixtureDataset(data_iters, weights, modality_info),
        wds.batched(batch_size, collation_fn=default_collate, partial=False),
    ).with_epoch(batches_per_worker)  # Pre-define iterator length

    mixture_loader = wds.WebLoader(mixture_pipe, num_workers=num_workers, batch_size=None)
    return mixture_loader