from __future__ import annotations
import random
from typing import Optional
import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from .subsets import NuScenesDataset, YouTubeDataset
try:
from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
print("#" * 100)
print("Datasets not yet available")
print("To enable, we need to add stable-datasets as a submodule")
print("Please use ``git submodule update --init --recursive``")
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
print("#" * 100)
exit(1)
class StableDataModuleFromConfig(LightningDataModule):
    """Lightning data module that builds its datapipelines and loaders from DictConfigs."""

def __init__(
self,
train: DictConfig,
validation: Optional[DictConfig] = None,
test: Optional[DictConfig] = None,
skip_val_loader: bool = False,
dummy: bool = False
):
super().__init__()
self.train_config = train
assert (
"datapipeline" in self.train_config and "loader" in self.train_config
), "Train config requires the fields `datapipeline` and `loader`"
self.val_config = validation
if not skip_val_loader:
if self.val_config is not None:
assert (
"datapipeline" in self.val_config and "loader" in self.val_config
), "Validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "WARNING: no validation datapipeline defined; falling back to the training datapipeline"
                )
                self.val_config = train
self.test_config = test
if self.test_config is not None:
assert (
"datapipeline" in self.test_config and "loader" in self.test_config
), "Test config requires the fields `datapipeline` and `loader`"
self.dummy = dummy
if self.dummy:
print("#" * 100)
print("Using dummy dataset, hope you are debugging")
print("#" * 100)
def setup(self, stage: str) -> None:
print("Preparing datasets")
if self.dummy:
data_fn = create_dummy_dataset
else:
data_fn = create_dataset
self.train_data_pipeline = data_fn(**self.train_config.datapipeline)
if self.val_config:
self.val_data_pipeline = data_fn(**self.val_config.datapipeline)
if self.test_config:
self.test_data_pipeline = data_fn(**self.test_config.datapipeline)
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
return create_loader(self.train_data_pipeline, **self.train_config.loader)
def val_dataloader(self) -> wds.DataPipeline:
return create_loader(self.val_data_pipeline, **self.val_config.loader)
def test_dataloader(self) -> wds.DataPipeline:
return create_loader(self.test_data_pipeline, **self.test_config.loader)
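
# Illustrative usage sketch (not part of the module): building the data module
# from an OmegaConf config. The concrete ``datapipeline``/``loader`` contents are
# assumptions about what ``sdata.create_dataset``/``create_loader`` accept; only
# the presence of the two keys is required by the asserts above.
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.create({
#       "train": {
#           "datapipeline": {...},  # kwargs forwarded to create_dataset
#           "loader": {...},        # kwargs forwarded to create_loader
#       }
#   })
#   dm = StableDataModuleFromConfig(train=cfg.train, dummy=True)
#   dm.setup("fit")
#   train_loader = dm.train_dataloader()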

def dataset_mapping(subset_list: list, target_height: int, target_width: int, num_frames: int):
    """Instantiate the dataset object for each subset name in ``subset_list``."""
    datasets = []
for subset_name in subset_list:
if subset_name == "YouTube":
datasets.append(
YouTubeDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
)
elif subset_name == "NuScenes":
datasets.append(
NuScenesDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
)
else:
raise NotImplementedError(f"Please define {subset_name} as a subset")
return datasets
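
# Example (hypothetical call): build the YouTube and NuScenes subsets at the
# default resolution and clip length used elsewhere in this module.
#
#   subsets = dataset_mapping(
#       ["YouTube", "NuScenes"], target_height=320, target_width=576, num_frames=25
#   )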

class MultiSourceSamplerDataset(Dataset):
    """Dataset that draws each item from one of several subsets according to sampling weights."""

    def __init__(self, subsets, probs, samples_per_epoch=1000, target_height=320, target_width=576, num_frames=25):
self.subsets = dataset_mapping(subsets, target_height, target_width, num_frames)
        # if no probabilities are provided, weight each subset by its size,
        # i.e. sample uniformly over the union of all samples
        if probs is None:
            probs = [len(d) for d in self.subsets]
# normalize
total_prob = sum(probs)
self.sample_probs = [x / total_prob for x in probs]
self.samples_per_epoch = samples_per_epoch
def __len__(self):
return self.samples_per_epoch
def __getitem__(self, index):
"""
Args:
----
index (int): Index (ignored since we sample randomly).
Returns:
-------
TensorDict: Dict containing all the data blocks.
"""
        # randomly select a subset based on the sampling weights
        subset = random.choices(self.subsets, weights=self.sample_probs)[0]
        # retry until a valid sample is drawn; catch only Exception so that
        # KeyboardInterrupt and SystemExit still propagate
        while True:
            try:
                # return the sampled item
                return random.choice(subset)
            except Exception:
                continue
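
# Sketch of the sampling behaviour (values are assumptions): with
# probs=[0.7, 0.3], roughly 70% of the items drawn per epoch come from the
# YouTube subset and 30% from NuScenes, regardless of the subsets' sizes.
#
#   dataset = MultiSourceSamplerDataset(
#       subsets=["YouTube", "NuScenes"], probs=[0.7, 0.3], samples_per_epoch=1000
#   )
#   item = dataset[0]  # the index is ignored; a subset is drawn by weight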

class Sampler(LightningDataModule):
    """Lightning data module wrapping ``MultiSourceSamplerDataset`` in standard DataLoaders."""

    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True, subsets=None, probs=None,
                 samples_per_epoch=1000, target_height=320, target_width=576, num_frames=25):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
        # DataLoader only accepts a prefetch_factor when workers are used;
        # pass None (the default) for single-process loading
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
self.shuffle = shuffle
self.train_dataset = MultiSourceSamplerDataset(
subsets=subsets, probs=probs, samples_per_epoch=samples_per_epoch,
target_height=target_height, target_width=target_width, num_frames=num_frames
)
def prepare_data(self):
pass
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor
)
def test_dataloader(self):
return DataLoader(
self.train_dataset, # we disable online testing to improve training efficiency
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor
)
def val_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor
)
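
# End-to-end sketch (hypothetical values; requires the subset data to be
# available on disk):
#
#   datamodule = Sampler(
#       batch_size=2, num_workers=4, subsets=["YouTube", "NuScenes"],
#       probs=[0.7, 0.3], samples_per_epoch=1000,
#   )
#   for batch in datamodule.train_dataloader():
#       ...  # each item stacks ``num_frames`` frames from one subset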