File size: 12,564 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import copy
import os
import random
from operator import itemgetter
from typing import Optional, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from torch.utils.data import Dataset, Sampler
from torch.utils.data import Sampler, DistributedSampler


def chunk_indices(indices: list[int], size: int) -> tuple[torch.Tensor, ...]:
    """Split a flat list of indices into consecutive chunks of at most ``size``.

    Args:
        indices: integer indices, kept in the given order.
        size: maximum number of indices per chunk; the last chunk may be shorter.

    Returns:
        Tuple of 1-D ``torch.long`` tensors. An empty ``indices`` yields an empty
        tuple rather than a single empty chunk.
    """
    if len(indices) == 0:
        # torch.split on an empty tensor returns one empty chunk, which would
        # surface as a zero-length batch; also torch.tensor([]) defaults to float.
        return ()
    # Explicit long dtype so chunks are always valid index tensors.
    return torch.split(torch.tensor(indices, dtype=torch.long), size)


class CombinedDataLoader:
    """Iterate several dataloaders back to back as a single epoch.

    Each dataloader is drained fully, in list order, before moving to the next.
    When all are exhausted, ``StopIteration`` is raised. If ``reinit`` is true,
    the next ``iter()`` call (e.g. the next ``for`` loop) recreates the
    underlying iterators, so the combined loader can be reused across epochs.
    """

    def __init__(self, dataloaders, reinit=True):
        """
        :param dataloaders: list of pytorch dataloaders (any re-iterable with ``len``)
        :param reinit: recreate the per-loader iterators for the next epoch once
            all of them have been exhausted
        """
        self.dataloaders = dataloaders
        self.reinit = reinit
        self.dataloader_idx = 0
        self.loader_iters = [iter(dataloader) for dataloader in self.dataloaders]
        # Set once every loader is drained; consumed by __iter__ to start a fresh epoch.
        self._exhausted = False

    def __iter__(self):
        # Recreate iterators lazily at the start of the next epoch rather than
        # eagerly at the end of the last one (avoids spinning up unused workers).
        if self._exhausted and self.reinit:
            self.loader_iters = [iter(dataloader) for dataloader in self.dataloaders]
            self._exhausted = False
        return self

    def __next__(self):
        # Sequentially drain each loader; a loop (not recursion) skips any number
        # of empty/exhausted loaders without growing the stack.
        while self.dataloader_idx < len(self.loader_iters):
            try:
                return next(self.loader_iters[self.dataloader_idx])
            except StopIteration:
                self.dataloader_idx += 1
        # All loaders exhausted: reset position and end the epoch.
        self.dataloader_idx = 0
        self._exhausted = True
        raise StopIteration

    def __len__(self):
        """Total number of batches across all dataloaders."""
        return sum(len(dataloader) for dataloader in self.dataloaders)


class CombinedBatchSampler(torch.utils.data.Sampler):
    """Batch sampler over several datasets laid out back to back (for validation dataloaders).

    Global indices are assigned as if the datasets were concatenated in the
    given order: dataset ``k`` owns ``[sum(len(d) for d in datasets[:k]), ...)``.
    NOTE(review): presumably the datasets are wrapped in a ``ConcatDataset`` in
    the same order by the caller -- confirm.

    Every batch contains indices from exactly one dataset. For each dataset the
    trailing ``len(dataset) % (batch_size * num_processes)`` indices are dropped,
    so every emitted batch is full and the batch count per dataset is a
    multiple of ``num_processes``.
    """

    def __init__(self, datasets, batch_size, num_processes=1, shuffle=False):
        super().__init__()  # no-op
        offset = 0
        all_batches = []
        step = batch_size * num_processes

        for dataset in datasets:
            dataset_len = len(dataset)
            indices = list(range(offset, offset + dataset_len))
            # Advance by the full dataset length so the next dataset's indices
            # start at its own first element, regardless of how many indices are
            # dropped from this one below.
            offset += dataset_len

            if shuffle:
                random.shuffle(indices)

            # exclude the remainder, if necessary
            remainder = len(indices) % step
            if remainder > 0:
                indices = indices[:-remainder]

            # Dataset smaller than one full step: contributes no batches.
            if not indices:
                continue

            all_batches += chunk_indices(indices, batch_size)  # equally sized

        if shuffle:
            random.shuffle(all_batches)

        self.all_batches = all_batches

    def __iter__(self):
        return iter(self.all_batches)

    def __len__(self):
        return len(self.all_batches)


# https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py
class DatasetFromSampler(Dataset):
    """Expose the indices produced by a ``Sampler`` as an indexable dataset.

    The sampler is only iterated on the first ``__getitem__`` call; the
    resulting list is cached on the instance and reused for later lookups.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Keep a reference to ``sampler``; materialisation is deferred."""
        self.sampler = sampler
        # Lazily populated cache of list(sampler).
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Return the sampler's element at position ``index``.

        Args:
            index: position within the sampler's output sequence

        Returns:
            The element the sampler yields at that position.
        """
        cached = self.sampler_list
        if cached is None:
            cached = self.sampler_list = list(self.sampler)
        return cached[index]

    def __len__(self) -> int:
        """
        Returns:
            int: ``len(sampler)``, as reported by the wrapped sampler
        """
        return len(self.sampler)


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    The wrapped sampler's output is treated as a dataset of indices. The parent
    ``DistributedSampler`` picks this rank's positions into that sequence, and
    the corresponding sampler elements are yielded.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """

        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super().__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        """Iterate over sampler.

        Returns:
            python iterator
        """
        # Re-wrap each epoch so a stochastic sampler is re-iterated (fresh draw).
        self.dataset = DatasetFromSampler(self.sampler)
        subsampler_indexes = self.dataset
        # A comprehension (rather than itemgetter(*positions)) also handles a rank
        # that receives exactly one position (itemgetter would return a bare
        # scalar) or none at all (itemgetter() raises TypeError).
        return iter([subsampler_indexes[position] for position in super().__iter__()])


# https://github.com/rabeehk/hyperformer/blob/main/hyperformer/data/multitask_sampler.py
class MultiTaskBatchSampler(Sampler):
    """Defines a sampler to sample multiple datasets with temperature sampling
    in a distributed fashion.

    The datasets are indexed as if concatenated in order (``dataset_offsets``).
    Each yielded batch is a list of ``batch_size`` global indices drawn with
    replacement from a single dataset's shard for this rank.
    """

    def __init__(
        self,
        dataset_sizes: List[int],
        batch_size: int,
        temperature: float,
        dataset_groups=None,
        num_replicas: Optional[int] = 1,
        rank: Optional[int] = 0,
        seed: int = 0,
        shuffle: bool = True,
        shuffle_task: bool = True,
    ) -> None:
        """Constructor for MultiTaskBatchSampler.
        Args:
            dataset_sizes: a list of integers, specifies the number of samples in
                each dataset.
            batch_size: integer, specifies the batch size.
            temperature: float, temperature used for temperature sampling. Dataset
                weights are proportional to ``(size / total) ** (1 / temperature)``:
                a value of 1 samples datasets in proportion to their number of
                samples, and larger values push the weights toward uniform.
                Must be non-zero (it is used as a divisor).
            dataset_groups: optional list of ``(lo, hi)`` half-open ranges of
                dataset indices. Each group receives an equal ``1 / num_groups``
                share of the probability mass, split among its datasets by
                temperature sampling. Group weights are concatenated in list
                order, so they presumably should cover all datasets contiguously
                and in order -- verify against callers. ``None`` (default) means
                no grouping.
            num_replicas: integer, specifies the number of processes.
            rank: integer, specifies the rank of the current process.
            seed: integer, random seed.
            shuffle: bool, if set to true, the datasets will be shuffled in each epoch.
            shuffle_task: bool, if true, each rank draws its own sequence of
                batch tasks; otherwise all ranks draw the same sequence.
        """

        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
            print("data sampler rank:", rank)

        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval" " [0, {}]".format(rank, num_replicas - 1)
            )

        # Avoid a shared mutable default argument.
        self.dataset_groups = [] if dataset_groups is None else dataset_groups
        print("dataset groups:", self.dataset_groups)

        self.num_replicas = num_replicas
        self.shuffle_task = shuffle_task
        self.rank = rank
        self.batch_size = batch_size
        self.dataset_sizes = dataset_sizes

        # By default we drop the last elements if dataset is not divisible by the number of ranks.
        self.rank_dataset_sizes = [dataset_size // self.num_replicas for dataset_size in self.dataset_sizes]
        # Start offset of each dataset in the concatenated index space.
        self.dataset_offsets = torch.cumsum(torch.LongTensor([0] + list(dataset_sizes)), 0)
        self.total_sizes = [
            (dataset_size // self.num_replicas) * self.num_replicas for dataset_size in self.dataset_sizes
        ]
        self.temperature = temperature
        self.seed = seed
        self.epoch = 0
        # ceil(total / batch_size) split evenly across ranks; cast from numpy to a plain int.
        self.num_batches_per_epoch = int(
            (np.sum(dataset_sizes) + self.batch_size - 1) // self.batch_size // self.num_replicas
        )
        self.shuffle = shuffle
        print(f"{num_replicas=} {rank=} {self.num_batches_per_epoch=} {self.total_sizes=} self.weights={self.generate_tasks_distribution()}")

    def generate_tasks_distribution(self):
        """Given the dataset sizes computes the weights to sample each dataset
        according to the temperature sampling.

        Returns:
            1-D double tensor of per-dataset sampling weights summing to 1.
        """
        if len(self.dataset_groups) > 0:
            # normalize within each group, then give every group an equal share
            weights = []
            num_groups = len(self.dataset_groups)
            for group in self.dataset_groups:
                lo, hi = group
                dataset_sizes = [self.dataset_sizes[idx] for idx in range(lo, hi)]
                total_size = sum(dataset_sizes)
                group_weights = np.array([(size / total_size) ** (1.0 / self.temperature) for size in dataset_sizes])
                group_weights = group_weights / np.sum(group_weights) / num_groups
                weights = np.concatenate((weights, group_weights))

        else:
            total_size = sum(self.dataset_sizes)
            weights = np.array([(size / total_size) ** (1.0 / self.temperature) for size in self.dataset_sizes])
            weights = weights / np.sum(weights)
        return torch.as_tensor(weights, dtype=torch.double)

    def __iter__(self):
        # Seed from (seed, epoch) so the per-dataset permutations below are
        # identical on every rank and change from epoch to epoch.
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)

        # Shuffles the datasets if shuffle is set to true.
        indices = []
        for dataset_size in self.dataset_sizes:
            if self.shuffle:
                indices.append(torch.randperm(dataset_size, generator=generator).tolist())
            else:
                indices.append(list(range(dataset_size)))

        # Shards each dataset across processes with a rank-strided slice; the
        # tail beyond total_sizes[i] is dropped so every rank gets the same count.
        self.rank_indices = [
            indices[i][self.rank : self.total_sizes[i] : self.num_replicas] for i in range(len(self.dataset_sizes))
        ]
        # Tensor copies of each shard, built once per epoch instead of once per batch.
        rank_index_tensors = [torch.tensor(shard, dtype=torch.long) for shard in self.rank_indices]

        tasks_distribution: torch.Tensor = self.generate_tasks_distribution()

        # Up to here the generator state is identical on all ranks. With
        # shuffle_task it is re-seeded with the rank mixed in, so each rank draws
        # its own task sequence (and in-task positions). Without it, all ranks
        # draw the same task sequence.
        if self.shuffle_task:
            generator.manual_seed(self.seed + self.epoch + self.rank)
        batch_task_assignments = torch.multinomial(
            tasks_distribution, self.num_batches_per_epoch, replacement=True, generator=generator
        )

        for batch_task in batch_task_assignments:
            # Number of samples of the selected dataset held by this rank.
            num_task_samples = self.rank_dataset_sizes[batch_task]
            # Positions within this rank's shard, drawn with replacement.
            positions = torch.randint(low=0, high=num_task_samples, size=(self.batch_size,), generator=generator)
            # Map shard positions to dataset-local indices, then to global indices.
            yield (self.dataset_offsets[batch_task] + rank_index_tensors[batch_task][positions]).tolist()

    def __len__(self):
        return self.num_batches_per_epoch

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next ``__iter__`` call."""
        self.epoch = epoch

def make_dataset_pie_plot(domains, traj_nums):
    """Draw the dataset mixture as a pie plot and return it as an RGB PIL image.

    Args:
        domains: legend labels, one per slice (same order as ``traj_nums``).
        traj_nums: per-domain counts; each slice is ``count / sum(traj_nums)``.

    Returns:
        ``PIL.Image.Image`` in RGB mode containing the rendered figure.
    """
    traj_prob = np.array(traj_nums) / np.sum(traj_nums)

    # Combine three qualitative colormaps to get 60 distinct colors; pie()
    # cycles through the given colors if there are more slices than that.
    colors = plt.get_cmap("tab20").colors + plt.get_cmap("tab20b").colors + plt.get_cmap("tab20c").colors

    fig, ax = plt.subplots(figsize=(40, 40))
    try:
        patches, _ = ax.pie(traj_prob, startangle=90, colors=colors[: len(traj_prob)])
        ax.axis("equal")
        ax.legend(patches, list(domains), loc="center left", bbox_to_anchor=(0.8, 0.5), prop={"size": 32})
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed tostring_rgb() and
        # carries the true pixel shape (correct on HiDPI canvases too).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert("RGB")
    finally:
        # Close the figure so repeated calls don't accumulate open figures.
        plt.close(fig)