File size: 10,646 Bytes
fcdfd72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import cv2
import torch
import numpy as np
import pycocotools.mask as mask_utils

# Flip methods accepted by the ``transpose`` methods of the mask classes in this file.
FLIP_LEFT_RIGHT = 0  # mirror along the x axis (coordinates measured against the image width)
FLIP_TOP_BOTTOM = 1  # mirror along the y axis (coordinates measured against the image height)


class MaskList(object):
    """
    Container for the binary masks of all instances of one image.

    This class is unfinished and not meant for use yet (``bbox`` is still a stub).

    Two storage modes exist:
      * ``"mask"``: ``masks`` is a list of 2D arrays of shape (H, W), one per instance
      * ``"rle"``:  ``masks`` is the list of run-length encodings produced by
        ``pycocotools.mask.encode``
    """

    def __init__(self, masks, size, mode):
        """
        :param masks: list, the per-instance masks (see class docstring)
        :param size: tuple, (image_width, image_height)
        :param mode: str, either "mask" or "rle"
        """
        assert(isinstance(masks, list))
        assert(mode in ['mask', 'rle'])
        self.masks = masks
        self.size = size # (image_width, image_height)
        self.mode = mode

    def _as_array(self):
        """Stack the per-instance masks into one (N, H, W) numpy array."""
        stacked = np.array(self.masks)
        if stacked.ndim == 2:
            stacked = np.expand_dims(stacked, axis=0)
        return stacked

    @staticmethod
    def _as_list(stacked):
        """Split an (N, H, W) array into a list of N arrays of shape (H, W)."""
        # Iterating over axis 0 yields (H, W) views and, unlike np.split,
        # also works when N == 0.
        return list(stacked)

    def transpose(self, method):
        """
        Flip every binary mask.
        :param method: FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM
        :return: a new MaskList with the flipped masks (same size and mode)
        :raises NotImplementedError: for any other method
        """
        assert (self.mode == "mask"), "RLE masks cannot be transposed. Please convert them to binary first."
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        if not self.masks:
            # nothing to flip; np.flip would fail on the resulting 1D array
            return MaskList([], self.size, self.mode)

        # stacked layout is (N, H, W): axis 2 is the width, axis 1 the height
        flip_axis = 2 if method == FLIP_LEFT_RIGHT else 1
        flipped = np.flip(self._as_array(), axis=flip_axis)
        return MaskList(self._as_list(flipped), self.size, self.mode)

    def resize(self, size, *args, **kwargs):
        """
        Resize the binary mask.
        :param size: tuple, (image_width, image_height)
        :param args: ignored
        :param kwargs: ignored
        :return: a new MaskList in mode "mask" whose masks are integer arrays
        """
        assert(self.mode == "mask"), "RLE masks cannot be resized. Please convert them to binary first."
        if not self.masks:
            return MaskList([], size, "mask")

        # cv2.resize works on (H, W, C) images, so put the instances on the last axis
        channels_last = self._as_array().transpose(1, 2, 0)
        # Multiply out of place: an in-place ``*= 255`` raises for boolean masks
        channels_last = (channels_last * 255).astype(np.uint8)
        resized = cv2.resize(channels_last, size)
        if resized.ndim == 2:
            # the channel axis is missing when cv2 hands back a 2D array; restore it
            resized = np.expand_dims(resized, axis=2)
        # back to (N, H, W); integer division keeps only pixels that are exactly 255
        resized = resized.transpose(2, 0, 1).astype(int) // 255
        return MaskList(self._as_list(resized), size, "mask")

    def pad(self, size):
        """
        pad the binary masks according to the new size. New size must be larger than original size in all dimensions
        The original content stays in the top-left corner, the added area is zero
        and the dtype of the masks is preserved.
        :param size: New image size, (image_width, image_height)
        :return: a new MaskList in mode "mask"
        """
        assert(self.mode == "mask"), "RLE masks cannot be padded. Please convert them to binary first."
        assert(size[0] >= self.size[0] and size[1] >= self.size[1]), "New size must be larger than original size in all dimensions"
        if not self.masks:
            return MaskList([], size, "mask")

        stacked = self._as_array()
        # size is (width, height) while the arrays are laid out (N, height, width)
        padded = np.zeros((stacked.shape[0], size[1], size[0]), dtype=stacked.dtype)
        padded[:, :stacked.shape[1], :stacked.shape[2]] = stacked
        return MaskList(self._as_list(padded), size, "mask")

    def convert(self, mode):
        """
        Convert mask from between mode "mask" and mode "rle"
        :param mode: str, target mode, "mask" or "rle"
        :return: self when already in the requested mode, otherwise a new MaskList
        :raises ValueError: if mode is neither "mask" nor "rle"
        """
        if mode == self.mode:
            return self
        if mode not in ("mask", "rle"):
            raise ValueError(
                "Cannot convert MaskList from mode {!r} to mode {!r}".format(self.mode, mode)
            )
        if not self.masks:
            # nothing to encode or decode
            return MaskList([], self.size, mode)

        if mode == "rle":
            # use pycocotools to encode binary masks to rle; the masks are passed
            # as one Fortran-ordered uint8 array of shape (H, W, N)
            height_width_n = self._as_array().transpose(1, 2, 0).astype(np.uint8)
            rle_mask_list = mask_utils.encode(np.asfortranarray(height_width_n))
            return MaskList(rle_mask_list, self.size, "rle")

        # use pycocotools to decode rle to binary masks; the result is indexed
        # as (H, W, N) here, so move the instance axis to the front
        bimasks = mask_utils.decode(self.masks)
        return MaskList(self._as_list(bimasks.transpose(2, 0, 1)), self.size, "mask")

    def bbox(self, bbox_mode="xyxy"):
        """
        Generate a bounding box according to the binary mask
        Not implemented yet: currently does nothing and returns None.
        :param bbox_mode: unused for now
        :return: None
        """
        pass

    def __len__(self):
        """Number of stored masks."""
        return len(self.masks)

    def __repr__(self):
        return "{}(num_masks={}, image_width={}, image_height={}, mode={})".format(
            self.__class__.__name__, len(self), self.size[0], self.size[1], self.mode
        )


class Polygons(object):
    """
    The mask of a single object instance, described by a set of polygons.

    Each polygon is a flat sequence of numbers; the methods below treat the
    even positions as x coordinates (measured against the image width) and
    the odd positions as y coordinates (measured against the image height).
    """

    def __init__(self, polygons, size, mode):
        """
        Arguments:
            polygons: a list of flat coordinate sequences (each is turned into
                a float32 tensor), or another Polygons object whose tensor
                list is then reused as is
            size: (image_width, image_height)
            mode: stored unchanged and handed on to derived objects
        """
        if isinstance(polygons, Polygons):
            polygons = polygons.polygons
        elif isinstance(polygons, list):
            polygons = [
                torch.as_tensor(coords, dtype=torch.float32) for coords in polygons
            ]

        self.polygons = polygons
        self.size = size
        self.mode = mode

    def transpose(self, method):
        """
        Mirror all polygons horizontally (FLIP_LEFT_RIGHT) or vertically
        (FLIP_TOP_BOTTOM) and return them as a new Polygons object.
        Any other method raises NotImplementedError.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        width, height = self.size
        # start=0 picks the x coordinates, start=1 the y coordinates
        if method == FLIP_LEFT_RIGHT:
            extent, start = width, 0
        else:
            extent, start = height, 1

        def mirror(poly):
            mirrored = poly.clone()
            # subtract 1 so that coordinate c maps to extent - 1 - c
            mirrored[start::2] = extent - poly[start::2] - 1
            return mirrored

        return Polygons(
            [mirror(poly) for poly in self.polygons], size=self.size, mode=self.mode
        )

    def crop(self, box):
        """
        Shift all polygons so that the top-left corner of ``box`` (indexed as
        x1, y1, x2, y2) becomes the origin. Coordinates are not clamped to
        the box. The new size is the box extent, at least 1 in each dimension.
        """
        left, top = box[0], box[1]
        # TODO check if the lower bound of 1 is necessary
        new_width = max(box[2] - left, 1)
        new_height = max(box[3] - top, 1)

        def shift(poly):
            shifted = poly.clone()
            shifted[0::2] = poly[0::2] - left
            shifted[1::2] = poly[1::2] - top
            return shifted

        return Polygons(
            [shift(poly) for poly in self.polygons],
            size=(new_width, new_height),
            mode=self.mode,
        )

    def resize(self, size, *args, **kwargs):
        """
        Scale the polygons from the current size to ``size`` (width, height).
        Extra positional and keyword arguments are ignored.
        """
        scale = tuple(
            float(new) / float(old) for new, old in zip(size, self.size)
        )

        if scale[0] != scale[1]:
            # different factors per axis: scale x and y positions separately
            scale_x, scale_y = scale
            rescaled = []
            for poly in self.polygons:
                stretched = poly.clone()
                stretched[0::2] *= scale_x
                stretched[1::2] *= scale_y
                rescaled.append(stretched)
            return Polygons(rescaled, size=size, mode=self.mode)

        # same factor on both axes: one multiplication per polygon suffices
        uniform = scale[0]
        return Polygons(
            [poly * uniform for poly in self.polygons], size, mode=self.mode
        )

    def convert(self, mode):
        """
        For mode "mask", rasterize the polygons with pycocotools, merge them
        into one mask and return it as a torch tensor. Any other mode
        yields None.
        """
        width, height = self.size
        if mode != "mask":
            return None

        numpy_polygons = [poly.detach().numpy() for poly in self.polygons]
        encoded = mask_utils.frPyObjects(numpy_polygons, height, width)
        merged = mask_utils.merge(encoded)
        # TODO add squeeze?
        return torch.from_numpy(mask_utils.decode(merged))

    def __repr__(self):
        return "{}(num_polygons={}, image_width={}, image_height={}, mode={})".format(
            self.__class__.__name__,
            len(self.polygons),
            self.size[0],
            self.size[1],
            self.mode,
        )


class SegmentationMask(object):
    """
    This class stores the segmentations for all objects in the image
    """

    def __init__(self, polygons, size, mode=None):
        """
        Arguments:
            polygons: a list of list of lists of numbers. The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                object, and the third level to the polygon coordinates.
                Elements that already are Polygons objects are accepted
                as well and are re-wrapped with the given size and mode.
            size: (image_width, image_height)
            mode: handed on to every Polygons object
        """
        assert isinstance(polygons, list)

        self.polygons = [Polygons(p, size, mode) for p in polygons]
        self.size = size
        self.mode = mode

    def transpose(self, method):
        """
        Flip every instance with FLIP_LEFT_RIGHT or FLIP_TOP_BOTTOM.
        Any other method raises NotImplementedError.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        flipped = [polygon.transpose(method) for polygon in self.polygons]
        return SegmentationMask(flipped, size=self.size, mode=self.mode)

    def crop(self, box):
        """
        Crop every instance to ``box`` (indexed as x1, y1, x2, y2).
        The new size is the box extent, at least 1 in each dimension.
        """
        w, h = box[2] - box[0], box[3] - box[1]
        # Same lower bound as Polygons.crop. The constructor re-wraps every
        # instance with this size, so a zero extent would otherwise end up in
        # the instances and make a later resize divide by zero.
        w = max(w, 1)
        h = max(h, 1)
        cropped = [polygon.crop(box) for polygon in self.polygons]
        return SegmentationMask(cropped, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        """
        Resize every instance to ``size`` (width, height); extra arguments
        are forwarded to ``Polygons.resize``.
        """
        scaled = [polygon.resize(size, *args, **kwargs) for polygon in self.polygons]
        return SegmentationMask(scaled, size=size, mode=self.mode)

    def to(self, *args, **kwargs):
        """No-op that returns self; all arguments are ignored."""
        return self

    def __getitem__(self, item):
        """
        Select instances and return them as a new SegmentationMask.

        ``item`` may be an int, a slice, a boolean torch tensor used as a
        mask over the instances, or any iterable of indices.
        """
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            # A slice already yields a list of instances; wrapping it in
            # another list would turn the whole selection into one bogus
            # instance.
            selected_polygons = self.polygons[item]
        else:
            # advanced indexing on a single dimension
            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
                # boolean mask -> list of the positions that are set
                item = item.nonzero()
                item = item.squeeze(1) if item.numel() > 0 else item
                item = item.tolist()
            selected_polygons = [self.polygons[i] for i in item]
        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)

    def __iter__(self):
        """Iterate over the per-instance Polygons objects."""
        return iter(self.polygons)

    def __repr__(self):
        return "{}(num_instances={}, image_width={}, image_height={})".format(
            self.__class__.__name__, len(self.polygons), self.size[0], self.size[1]
        )