from typing import Iterator, Optional

import numpy as np
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):
    """
    A custom distributed sampler that supports shuffling, round-up of the sample size, 
    and ensures deterministic shuffling across epochs.

    Args:
        dataset: The dataset from which samples are drawn.
        num_replicas: Optional; the number of processes participating in the distributed training.
        rank: Optional; the rank of the current process among num_replicas.
        shuffle: Optional; whether to shuffle the dataset every epoch. Defaults to True.
        round_up: Optional; whether to round up the total size to make it divisible among replicas.
                  Defaults to True.

    Attributes:
        shuffle (bool): Whether to shuffle the dataset.
        round_up (bool): Whether to round up the total size to make it evenly divisible among replicas.
        total_size (int): The total number of samples.
    """
    
    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 round_up: bool = True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.round_up = round_up
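        if self.round_up:
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(self.dataset)
            # Without padding, ranks get unequal shares; make __len__ (which
            # returns num_samples) report this rank's actual count.
            self.num_samples = len(
                range(self.rank, self.total_size, self.num_replicas))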

    def __iter__(self) -> Iterator[int]:
        """
        Returns an iterator over the indices of the dataset, shuffled if required, 
        with optional rounding up to make the number of samples divisible among replicas.

        Returns:
            Iterator[int]: An iterator over the indices for the current rank.
        """
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if self.round_up:
            assert len(indices) == self.num_samples

        return iter(indices)
    
    
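class DistributedWeightedRandomSampler(_DistributedSampler):
    """Distributed sampler that draws indices with replacement, by weight.

    Expects ``dataset.weights`` to hold one non-negative weight per sample;
    the weights are normalised to probabilities. Every rank draws the same
    ``len(dataset)`` indices from a NumPy generator seeded with the current
    epoch (call ``set_epoch`` each epoch), then keeps its strided slice.

    Args:
        dataset: Dataset exposing a ``weights`` sequence with one entry per
            sample.
        num_replicas: Number of processes participating in distributed
            training.
        rank: Rank of the current process within ``num_replicas``.
        shuffle: Unused; accepted for signature compatibility with
            ``DistributedSampler`` (sampling is always random).
        round_up: Whether to pad the drawn indices so their count is divisible
            by ``num_replicas``. Defaults to True.
    """

    def __init__(self,
                 dataset: torch.utils.data.Dataset,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 round_up: bool = True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.round_up = round_up
        if self.round_up:
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(self.dataset)
            # Without padding, ranks get unequal shares; make __len__ (which
            # returns num_samples) report this rank's actual count.
            self.num_samples = len(
                range(self.rank, self.total_size, self.num_replicas))

    def __iter__(self) -> Iterator[int]:
        # Normalise the weights: np.random choice() rejects p that does not
        # sum to 1.
        weights = np.asarray(self.dataset.weights, dtype=np.float64)
        probs = weights / weights.sum()
        # Seed from the epoch (instead of the global np.random state, which
        # may differ between ranks) so all ranks slice one shared list of
        # draws and each epoch is reproducible.
        rng = np.random.default_rng(self.epoch)
        indices = rng.choice(
            len(probs), size=len(self.dataset), replace=True, p=probs).tolist()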

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if self.round_up:
            assert len(indices) == self.num_samples

        return iter(indices)
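

# Usage sketch (a hypothetical demo helper, not part of the original module):
# it simulates several ranks in one process by passing ``num_replicas`` and
# ``rank`` explicitly, so torch.distributed need not be initialised. The toy
# dataset class below is an assumption made for illustration; the weighted
# sampler only needs ``__len__`` and a ``weights`` attribute from it.
def _demo():
    data = list(range(10))
    num_replicas = 4

    # Unshuffled, round_up=True: 10 indices are padded to 12 by repeating
    # the first two, and rank r takes every 4th index starting at r.
    shards = [
        list(DistributedSampler(data, num_replicas=num_replicas, rank=r,
                                shuffle=False))
        for r in range(num_replicas)
    ]
    print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]

    # Shuffled: every rank seeds its generator with the same epoch, so all
    # ranks slice one shared permutation and together cover the dataset.
    for epoch in range(2):
        covered = set()
        for r in range(num_replicas):
            sampler = DistributedSampler(data, num_replicas=num_replicas,
                                         rank=r)
            sampler.set_epoch(epoch)
            covered.update(sampler)
        assert covered == set(data)

    class _ToyWeightedDataset:
        # Hypothetical dataset whose weights already sum to 1, so only
        # indices 2 and 3 can ever be drawn.
        def __init__(self, weights):
            self.weights = weights

        def __len__(self):
            return len(self.weights)

    toy = _ToyWeightedDataset([0.0, 0.0, 0.5, 0.5])
    drawn = list(DistributedWeightedRandomSampler(toy, num_replicas=2, rank=0))
    print(drawn)  # two indices, each either 2 or 3
    return shards, drawn


if __name__ == "__main__":
    _demo()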