from __future__ import annotations

import random
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from .subsets import NuScenesDataset, YouTubeDataset

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    print("#" * 100)
    print("Datasets not yet available")
    print("To enable them, add stable-datasets as a submodule:")
    print("run ``git submodule update --init --recursive``")
    print("and then ``pip install -e stable-datasets/`` from the root of this repo")
    print("#" * 100)
    exit(1)
class StableDataModuleFromConfig(LightningDataModule):
    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "Train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "Validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "WARNING: no validation datapipeline defined, falling back to the training one"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "Test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("Using dummy dataset, hope you are debugging")
            print("#" * 100)
    def setup(self, stage: str) -> None:
        print("Preparing datasets")
        data_fn = create_dummy_dataset if self.dummy else create_dataset

        self.train_data_pipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_data_pipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_data_pipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        return create_loader(self.train_data_pipeline, **self.train_config.loader)

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_data_pipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_data_pipeline, **self.test_config.loader)
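# Usage sketch (an assumption, not part of this module): the exact contents of
# the `datapipeline` and `loader` fields are defined by the stable-datasets
# submodule; the keys shown below are placeholders.
#
#   from omegaconf import OmegaConf
#
#   train_cfg = OmegaConf.create(
#       {"datapipeline": {...}, "loader": {"batch_size": 4, "num_workers": 2}}
#   )
#   datamodule = StableDataModuleFromConfig(train=train_cfg, skip_val_loader=True)
#   datamodule.setup("fit")
#   train_loader = datamodule.train_dataloader()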
def dataset_mapping(subset_list: list, target_height: int, target_width: int, num_frames: int):
    """Map subset names to instantiated dataset objects."""
    datasets = list()
    for subset_name in subset_list:
        if subset_name == "YouTube":
            datasets.append(
                YouTubeDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
            )
        elif subset_name == "NuScenes":
            datasets.append(
                NuScenesDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
            )
        else:
            raise NotImplementedError(f"Please define {subset_name} as a subset")
    return datasets
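# Example (sketch): build the two currently supported subsets in one call.
#
#   datasets = dataset_mapping(["YouTube", "NuScenes"], target_height=320,
#                              target_width=576, num_frames=25)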
class MultiSourceSamplerDataset(Dataset):
    def __init__(self, subsets, probs, samples_per_epoch=1000, target_height=320, target_width=576, num_frames=25):
        self.subsets = dataset_mapping(subsets, target_height, target_width, num_frames)
        # if probabilities are not provided, weight each subset by its size,
        # i.e. sample uniformly over all samples
        if probs is None:
            probs = [len(d) for d in self.subsets]
        # normalize the weights into a probability distribution
        total_prob = sum(probs)
        self.sample_probs = [x / total_prob for x in probs]
        self.samples_per_epoch = samples_per_epoch

    def __len__(self):
        return self.samples_per_epoch

    def __getitem__(self, index):
        """
        Args:
        ----
            index (int): Index (ignored since we sample randomly).

        Returns:
        -------
            TensorDict: Dict containing all the data blocks.
        """
        # randomly select a subset based on the sampling weights
        subset = random.choices(self.subsets, weights=self.sample_probs)[0]
        # draw random indices until a sample loads successfully, skipping
        # corrupt or otherwise unreadable items
        while True:
            try:
                return random.choice(subset)
            except Exception:
                continue
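# Example (sketch): draw weighted samples, 70% YouTube / 30% NuScenes, with
# 1000 draws counted as one epoch. The weights here are placeholders.
#
#   dataset = MultiSourceSamplerDataset(
#       subsets=["YouTube", "NuScenes"], probs=[0.7, 0.3], samples_per_epoch=1000
#   )
#   sample = dataset[0]  # the index is ignored; a subset is drawn by weight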
class Sampler(LightningDataModule):
    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True, subsets=None, probs=None,
                 samples_per_epoch=None, target_height=320, target_width=576, num_frames=25):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # DataLoader only accepts a prefetch_factor when multiprocessing is
        # used; it must stay None when num_workers == 0
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
        self.shuffle = shuffle
        self.train_dataset = MultiSourceSamplerDataset(
            subsets=subsets, probs=probs, samples_per_epoch=samples_per_epoch,
            target_height=target_height, target_width=target_width, num_frames=num_frames
        )

    def prepare_data(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor
        )

    def test_dataloader(self):
        return DataLoader(
            self.train_dataset,  # we disable online testing to improve training efficiency
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor
        )

    def val_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor
        )
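# Usage sketch (hyperparameters below are placeholders, not recommended
# values): wire the sampler into a pytorch_lightning.Trainer as its datamodule.
#
#   import pytorch_lightning as pl
#
#   datamodule = Sampler(
#       batch_size=2, num_workers=4, subsets=["YouTube", "NuScenes"],
#       probs=None, samples_per_epoch=1000,
#   )
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(model, datamodule=datamodule)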