import math
import random
from copy import deepcopy
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler


class SAMDistributedSampler(DistributedSampler):
    """

    Modified from https://github.com/pytorch/pytorch/blob/97261be0a8f09bed9ab95d0cee82e75eebd249c3/torch/utils/data/distributed.py.

    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        sub_epochs_per_epoch: int = 1,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)

        self.sub_epoch = 0
        self.sub_epochs_per_epoch = sub_epochs_per_epoch
        self.set_sub_num_samples()

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        # keep only this sub-epoch's interleaved slice of the per-replica indices
        indices = indices[(self.sub_epoch % self.sub_epochs_per_epoch) :: self.sub_epochs_per_epoch]

        return iter(indices)

    def __len__(self) -> int:
        return self.sub_num_samples

    def set_sub_num_samples(self) -> None:
        # length of indices[sub_epoch::sub_epochs_per_epoch] for num_samples
        # per-replica indices: the first (num_samples % sub_epochs_per_epoch)
        # sub-epochs receive one extra sample
        self.sub_num_samples = self.num_samples // self.sub_epochs_per_epoch
        if self.num_samples % self.sub_epochs_per_epoch > self.sub_epoch:
            self.sub_num_samples += 1

    def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None:
        r"""

        Set the epoch for this sampler.



        When :attr:`shuffle=True`, this ensures all replicas

        use a different random ordering for each epoch. Otherwise, the next iteration of this

        sampler will yield the same ordering.



        Args:

            epoch (int): Epoch number.

            sub_epoch (int): Sub epoch number.

        """
        self.epoch = epoch
        self.sub_epoch = sub_epoch
        self.set_sub_num_samples()
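
# Usage sketch (illustrative, not part of the original module): set the epoch
# and sub-epoch before each pass so every replica iterates a deterministic,
# disjoint slice of that epoch's permutation.
#
#     sampler = SAMDistributedSampler(dataset, sub_epochs_per_epoch=4)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)
#     for epoch in range(num_epochs):
#         for sub_epoch in range(sampler.sub_epochs_per_epoch):
#             sampler.set_epoch_and_sub_epoch(epoch, sub_epoch)
#             for batch in loader:
#                 ...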


class RandomHFlip(object):
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )

        # flip horizontally with probability `prob`
        if random.random() < self.prob:
            image = torch.flip(image, dims=[2])
            masks = torch.flip(masks, dims=[2])
            points = deepcopy(points).to(torch.float)
            bboxs = deepcopy(bboxs).to(torch.float)
            # mirror x-coordinates: points are (x, y); boxes are assumed XYWH,
            # so the flipped left edge is W - w - x
            points[:, 0] = shape[-1] - points[:, 0]
            bboxs[:, 0] = shape[-1] - bboxs[:, 2] - bboxs[:, 0]

        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}


class ResizeLongestSide(object):
    """

    Modified from https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/utils/transforms.py.

    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        target_size = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
        return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)

    def apply_boxes(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """

        Expects a torch tensor with shape Bx4. Requires the original image

        size in (H, W) format.

        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_coords(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
        """

        Expects a torch tensor with length 2 in the last dimension. Requires the

        original image size in (H, W) format.

        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """

        Compute the output size given input size and target long side length.

        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )

        image = self.apply_image(image.unsqueeze(0), shape).squeeze(0)
        masks = self.apply_image(masks.unsqueeze(1), shape).squeeze(1)
        points = self.apply_coords(points, shape)
        bboxs = self.apply_boxes(bboxs, shape)

        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}


class Normalize_and_Pad(object):
    def __init__(self, target_length: int) -> None:
        self.target_length = target_length
        # SAM's pixel mean/std in 0-255 scale; inputs are expected unnormalized
        self.transform = transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])

    def __call__(self, sample):
        image, masks, points, bboxs, shape = (
            sample["image"],
            sample["masks"],
            sample["points"],
            sample["bboxs"],
            sample["shape"],
        )

        h, w = image.shape[-2:]
        image = self.transform(image)

        # pad bottom and right to a target_length x target_length square
        padh = self.target_length - h
        padw = self.target_length - w

        image = F.pad(image.unsqueeze(0), (0, padw, 0, padh), value=0).squeeze(0)
        masks = F.pad(masks.unsqueeze(1), (0, padw, 0, padh), value=0).squeeze(1)

        return {"image": image, "masks": masks, "points": points, "bboxs": bboxs, "shape": shape}