from __future__ import annotations

import random
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from .subsets import NuScenesDataset, YouTubeDataset

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    print("#" * 100)
    print("Datasets not yet available")
    print("To enable, we need to add stable-datasets as a submodule")
    print("Please use ``git submodule update --init --recursive``")
    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
    print("#" * 100)
    exit(1)


class StableDataModuleFromConfig(LightningDataModule):
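    """LightningDataModule that builds train/val/test datapipelines and loaders from OmegaConf configs via the stable-datasets helpers (``create_dataset``/``create_loader``)."""
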
    def __init__(
            self,
            train: DictConfig,
            validation: Optional[DictConfig] = None,
            test: Optional[DictConfig] = None,
            skip_val_loader: bool = False,
            dummy: bool = False
    ):
        super().__init__()
        self.train_config = train
        assert (
                "datapipeline" in self.train_config and "loader" in self.train_config
        ), "Train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                        "datapipeline" in self.val_config and "loader" in self.val_config
                ), "Validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "WARNING: no validation datapipeline defined, falling back to the training datapipeline"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                    "datapipeline" in self.test_config and "loader" in self.test_config
            ), "Test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("Using dummy dataset, hope you are debugging")
            print("#" * 100)

    def setup(self, stage: str) -> None:
        print("Preparing datasets")
        if self.dummy:
            data_fn = create_dummy_dataset
        else:
            data_fn = create_dataset

        self.train_data_pipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_data_pipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_data_pipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        return create_loader(self.train_data_pipeline, **self.train_config.loader)

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_data_pipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_data_pipeline, **self.test_config.loader)
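

# Illustrative config shape consumed by StableDataModuleFromConfig. The asserts in
# ``__init__`` only require each split config to provide ``datapipeline`` and ``loader``
# sub-configs, which are unpacked into ``create_dataset(**cfg.datapipeline)`` and
# ``create_loader(pipe, **cfg.loader)``. The nested keys sketched below (urls,
# batch_size, num_workers) are assumptions about the stable-datasets API, not values
# taken from this repo's configs.
#
#   train:
#     datapipeline:
#       urls: ["/data/shards/{00000..00099}.tar"]
#     loader:
#       batch_size: 4
#       num_workers: 8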


def dataset_mapping(subset_list: list, target_height: int, target_width: int, num_frames: int):
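    """Instantiate the dataset classes named in ``subset_list`` with a shared target resolution and clip length."""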
    datasets = list()
    for subset_name in subset_list:
        if subset_name == "YouTube":
            datasets.append(
                YouTubeDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
            )
        elif subset_name == "NuScenes":
            datasets.append(
                NuScenesDataset(target_height=target_height, target_width=target_width, num_frames=num_frames)
            )
        else:
            raise NotImplementedError(f"Please define {subset_name} as a subset")
    return datasets


class MultiSourceSamplerDataset(Dataset):
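    """
    Map-style dataset that randomly draws each item from one of several subsets.

    If ``probs`` is None, subsets are weighted by their number of samples; otherwise the
    given weights are normalized (e.g. probs=[3, 1] becomes [0.75, 0.25]). The epoch
    length is fixed by ``samples_per_epoch`` rather than by the subset sizes.
    """
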
    def __init__(self, subsets, probs, samples_per_epoch=1000, target_height=320, target_width=576, num_frames=25):
        self.subsets = dataset_mapping(subsets, target_height, target_width, num_frames)
        # if probabilities not provided, sample uniformly from all samples
        if probs is None:
            probs = [len(d) for d in self.subsets]
        # normalize
        total_prob = sum(probs)
        self.sample_probs = [x / total_prob for x in probs]
        self.samples_per_epoch = samples_per_epoch

    def __len__(self):
        return self.samples_per_epoch

    def __getitem__(self, index):
        """
        Args:
        ----
            index (int): Index (ignored since we sample randomly).

        Returns:
        -------
            TensorDict: Dict containing all the data blocks.

        """

        # randomly select a subset based on weights
        subset = random.choices(self.subsets, self.sample_probs)[0]

        # keep drawing random items until one can be loaded successfully
        while True:
            try:
                # return the sampled item
                return random.choice(subset)
            except Exception:
                # skip samples that fail to load (e.g. corrupted files) and retry
                continue


class Sampler(LightningDataModule):
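    """LightningDataModule wrapping MultiSourceSamplerDataset; val/test dataloaders reuse the training dataset since online evaluation is disabled."""
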
    def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True, subsets=None, probs=None,
                 samples_per_epoch=None, target_height=320, target_width=576, num_frames=25):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # DataLoader requires prefetch_factor to be None when num_workers == 0
        self.prefetch_factor = prefetch_factor if num_workers > 0 else None
        self.shuffle = shuffle
        self.train_dataset = MultiSourceSamplerDataset(
            subsets=subsets, probs=probs, samples_per_epoch=samples_per_epoch,
            target_height=target_height, target_width=target_width, num_frames=num_frames
        )

    def prepare_data(self):
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor
        )

    def test_dataloader(self):
        return DataLoader(
            self.train_dataset,  # we disable online testing to improve training efficiency
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor
        )

    def val_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor
        )
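

# Minimal smoke-test sketch (illustrative only): builds the multi-source sampler
# datamodule and pulls a single batch. It assumes the YouTube/NuScenes subsets can be
# constructed locally; the sizes below are debugging values, not training settings.
if __name__ == "__main__":
    datamodule = Sampler(
        batch_size=1,
        num_workers=0,
        subsets=["YouTube", "NuScenes"],
        probs=[0.8, 0.2],
        samples_per_epoch=4,
    )
    batch = next(iter(datamodule.train_dataloader()))
    print(type(batch))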