# vista/vwm/data/dataset.py
from __future__ import annotations

import random
import sys
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from .subsets import NuScenesDataset, YouTubeDataset
try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    print("#" * 100)
    print("Datasets not yet available.")
    print("To enable them, add stable-datasets as a submodule:")
    print("use ``git submodule update --init --recursive``")
    print("and run ``pip install -e stable-datasets/`` from the root of this repo.")
    print("#" * 100)
    sys.exit(1)

class StableDataModuleFromConfig(LightningDataModule):
def __init__(
self,
train: DictConfig,
validation: Optional[DictConfig] = None,
test: Optional[DictConfig] = None,
skip_val_loader: bool = False,
dummy: bool = False
):
super().__init__()
self.train_config = train
assert (
"datapipeline" in self.train_config and "loader" in self.train_config
), "Train config requires the fields `datapipeline` and `loader`"
self.val_config = validation
if not skip_val_loader:
if self.val_config is not None:
assert (
"datapipeline" in self.val_config and "loader" in self.val_config
), "Validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "WARNING: no validation datapipeline defined, falling back to the train datapipeline"
                )
                self.val_config = train
self.test_config = test
if self.test_config is not None:
assert (
"datapipeline" in self.test_config and "loader" in self.test_config
), "Test config requires the fields `datapipeline` and `loader`"
self.dummy = dummy
if self.dummy:
print("#" * 100)
print("Using dummy dataset, hope you are debugging")
print("#" * 100)
def setup(self, stage: str) -> None:
print("Preparing datasets")
if self.dummy:
data_fn = create_dummy_dataset
else:
data_fn = create_dataset
self.train_data_pipeline = data_fn(**self.train_config.datapipeline)
if self.val_config:
self.val_data_pipeline = data_fn(**self.val_config.datapipeline)
if self.test_config:
self.test_data_pipeline = data_fn(**self.test_config.datapipeline)
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
return create_loader(self.train_data_pipeline, **self.train_config.loader)
def val_dataloader(self) -> wds.DataPipeline:
return create_loader(self.val_data_pipeline, **self.val_config.loader)
def test_dataloader(self) -> wds.DataPipeline:
return create_loader(self.test_data_pipeline, **self.test_config.loader)
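
# A minimal usage sketch (hedged): the exact keys inside ``datapipeline`` and
# ``loader`` are defined by stable-datasets, so ``urls``, ``batch_size``, and
# ``num_workers`` below are illustrative assumptions, not taken from this file.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.create({
#       "train": {
#           "datapipeline": {"urls": ["data/train-{000000..000099}.tar"]},
#           "loader": {"batch_size": 4, "num_workers": 2},
#       }
#   })
#   datamodule = StableDataModuleFromConfig(train=config.train)
#   datamodule.setup("fit")
#   train_loader = datamodule.train_dataloader()
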
def dataset_mapping(subset_list: list, target_height: int, target_width: int, num_frames: int):
    """Instantiate the requested subsets with a shared target resolution and clip length."""
    datasets = []
for subset_name in subset_list:
if subset_name == "YouTube":
datasets.append(
YouTubeDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
)
elif subset_name == "NuScenes":
datasets.append(
NuScenesDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
)
else:
raise NotImplementedError(f"Please define {subset_name} as a subset")
return datasets
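
# For example (hedged): ``dataset_mapping(["YouTube", "NuScenes"], 320, 576, 25)``
# returns a ``YouTubeDataset`` and a ``NuScenesDataset``, both yielding 25-frame
# clips at 320x576, while any unknown subset name raises ``NotImplementedError``.
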
class MultiSourceSamplerDataset(Dataset):
    """Map-style dataset that draws each item from one of several subsets with configurable probabilities."""

    def __init__(self, subsets, probs, samples_per_epoch=1000, target_height=320, target_width=576, num_frames=25):
        self.subsets = dataset_mapping(subsets, target_height, target_width, num_frames)
        # if probabilities are not provided, weight by subset size (uniform over all samples)
        if probs is None:
            probs = [len(d) for d in self.subsets]
# normalize
total_prob = sum(probs)
self.sample_probs = [x / total_prob for x in probs]
self.samples_per_epoch = samples_per_epoch
def __len__(self):
return self.samples_per_epoch
    def __getitem__(self, index):
        """Sample one item from a randomly chosen subset.

        Args:
            index (int): Index (ignored since we sample randomly).

        Returns:
            TensorDict: Dict containing all the data blocks.
        """
        # randomly select a subset based on the sampling weights
        subset = random.choices(self.subsets, weights=self.sample_probs)[0]
        # retry with fresh random indices until a valid sample is returned
        while True:
            try:
                return random.choice(subset)
            except Exception:
                # the drawn sample was invalid (e.g. a corrupted clip), so draw again
                continue
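
# Sampling sketch (hedged, with hypothetical subset sizes): if the two subsets
# held 800 and 200 samples and ``probs=None``, the weights normalize to [0.8, 0.2]:
#
#   dataset = MultiSourceSamplerDataset(
#       subsets=["YouTube", "NuScenes"], probs=None, samples_per_epoch=1000
#   )
#   # ~80% of the 1000 draws per epoch then come from the larger subset,
#   # regardless of the index passed to ``dataset[i]``.
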
class Sampler(LightningDataModule):
    """Lightning datamodule that serves MultiSourceSamplerDataset for train/val/test loading."""

    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True, subsets=None, probs=None,
                 samples_per_epoch=1000, target_height=320, target_width=576, num_frames=25):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # DataLoader only accepts a prefetch_factor when num_workers > 0; it must stay None otherwise
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
        self.shuffle = shuffle
        self.train_dataset = MultiSourceSamplerDataset(
            subsets=subsets, probs=probs, samples_per_epoch=samples_per_epoch,
            target_height=target_height, target_width=target_width, num_frames=num_frames
        )
def prepare_data(self):
pass
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor
)
def test_dataloader(self):
return DataLoader(
self.train_dataset, # we disable online testing to improve training efficiency
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor
)
def val_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor
)
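

# A minimal smoke-test sketch (hedged): assumes NuScenes data is locally available
# to the ``NuScenesDataset`` subset; all parameter values here are illustrative.
if __name__ == "__main__":
    datamodule = Sampler(
        batch_size=2,
        num_workers=0,
        subsets=["NuScenes"],
        probs=None,
        samples_per_epoch=8,
    )
    batch = next(iter(datamodule.train_dataloader()))
    # the batch layout depends on what NuScenesDataset yields, so only its type is printed
    print(type(batch))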