File size: 13,049 Bytes
2cdd41c
 
 
1615d09
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
2cdd41c
 
1615d09
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
2cdd41c
 
 
 
1615d09
2cdd41c
1615d09
2cdd41c
 
 
 
1615d09
 
 
 
 
2cdd41c
 
 
1615d09
 
 
 
 
 
 
 
2cdd41c
 
 
1615d09
 
2cdd41c
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
1615d09
 
 
 
2cdd41c
 
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
2cdd41c
1615d09
 
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
1615d09
 
 
2cdd41c
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
2cdd41c
1615d09
 
 
2cdd41c
 
 
 
 
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
import math
import random
from functools import lru_cache

import cv2
import numpy as np

from .sample import DSample


class BasePointSampler:
    """Abstract point sampler that stores the currently selected mask.

    Subclasses implement ``sample_object`` and ``sample_points``; this base
    class only keeps the mask state and exposes it via ``selected_mask``.
    """

    def __init__(self):
        # Nothing is selected right after construction.
        self._selected_masks = None
        self._selected_mask = None

    def sample_points(self):
        """Draw points for the current selection (subclass responsibility)."""
        raise NotImplementedError

    def sample_object(self, sample: DSample):
        """Choose the target object(s) from ``sample`` (subclass responsibility)."""
        raise NotImplementedError

    @property
    def selected_mask(self):
        """The stored mask; asserts that one has been assigned."""
        current = self._selected_mask
        assert current is not None
        return current

    @selected_mask.setter
    def selected_mask(self, mask):
        # Stored with an extra leading axis and as float32.
        with_leading_axis = mask[np.newaxis, :]
        self._selected_mask = with_leading_axis.astype(np.float32)


class MultiPointSampler(BasePointSampler):
    """Sample positive and negative click points for one or more objects.

    ``sample_object`` picks the target mask(s) from a ``DSample`` and prepares
    the candidate regions; ``sample_points`` then draws the clicks from them.

    A click is a list ``[*coords, index]`` where ``coords`` is one row of
    ``np.argwhere(mask)`` and ``index`` is 0 for a "first click" and 100 for
    any other click.  Unused slots are padded with ``(-1, -1, -1)``.
    """

    def __init__(
        self,
        max_num_points,
        prob_gamma=0.7,
        expand_ratio=0.1,
        positive_erode_prob=0.9,
        positive_erode_iters=3,
        negative_bg_prob=0.1,
        negative_other_prob=0.4,
        negative_border_prob=0.5,
        merge_objects_prob=0.0,
        max_num_merged_objects=2,
        use_hierarchy=False,
        soft_targets=False,
        first_click_center=False,
        only_one_first_click=False,
        sfc_inner_k=1.7,
        sfc_full_inner_prob=0.0,
    ):
        """Configure the sampler.

        Args:
            max_num_points: Length of the positive list and of the negative
                list returned by ``sample_points`` (each is padded to it).
            prob_gamma: Decay factor forwarded to ``generate_probs`` for the
                distribution of the number of points to draw.
            expand_ratio: The border region is obtained by dilating the mask
                ``ceil(expand_ratio * sqrt(mask area))`` times.
            positive_erode_prob: Probability of eroding a mask before
                sampling from it.
            positive_erode_iters: Number of erosion iterations.
            negative_bg_prob: Weight of the "bg" negative strategy.
            negative_other_prob: Weight of the "other" negative strategy.
            negative_border_prob: Weight of the "border" negative strategy.
                The three weights must sum to 1.
            merge_objects_prob: Probability of merging several root objects
                into one target.
            max_num_merged_objects: Upper bound on merged objects; ``-1``
                means ``max_num_points``.
            use_hierarchy: Whether to descend the sample's object tree when
                choosing the target node.
            soft_targets: Use ``sample.get_soft_object_mask`` for the target
                mask instead of the hard object mask.
            first_click_center: Draw the first positive click from
                ``get_point_candidates`` instead of the whole mask.
            only_one_first_click: With several objects, keep only the first
                object's leading click.
            sfc_inner_k: Passed to ``get_point_candidates`` as ``k``.
            sfc_full_inner_prob: Passed to ``get_point_candidates`` as
                ``full_prob``.
        """
        super().__init__()
        self.max_num_points = max_num_points
        self.expand_ratio = expand_ratio
        self.positive_erode_prob = positive_erode_prob
        self.positive_erode_iters = positive_erode_iters
        self.merge_objects_prob = merge_objects_prob
        self.use_hierarchy = use_hierarchy
        self.soft_targets = soft_targets
        self.first_click_center = first_click_center
        self.only_one_first_click = only_one_first_click
        self.sfc_inner_k = sfc_inner_k
        self.sfc_full_inner_prob = sfc_full_inner_prob

        if max_num_merged_objects == -1:
            max_num_merged_objects = max_num_points
        self.max_num_merged_objects = max_num_merged_objects

        self.neg_strategies = ["bg", "other", "border"]
        self.neg_strategies_prob = [
            negative_bg_prob,
            negative_other_prob,
            negative_border_prob,
        ]
        assert math.isclose(sum(self.neg_strategies_prob), 1.0)

        # Positive sampling always draws at least one point (1..max), negative
        # sampling may draw none (0..max), hence the extra entry.
        self._pos_probs = generate_probs(max_num_points, gamma=prob_gamma)
        self._neg_probs = generate_probs(max_num_points + 1, gamma=prob_gamma)
        self._neg_masks = None

    def sample_object(self, sample: DSample):
        """Select the target mask(s) from ``sample`` and build negative regions.

        For an empty sample the target is an all-zero mask and every negative
        strategy uses the background mask.  Otherwise the negative regions
        are: "bg" (everything outside the target), "border" (a ring around
        the target), "other" (non-background pixels outside the target) and
        "required" (segments that must receive negative clicks).
        """
        if len(sample) == 0:
            bg_mask = sample.get_background_mask()
            self.selected_mask = np.zeros_like(bg_mask, dtype=np.float32)
            self._selected_masks = [[]]
            self._neg_masks = {strategy: bg_mask for strategy in self.neg_strategies}
            self._neg_masks["required"] = []
            return

        gt_mask, pos_masks, neg_masks = self._sample_mask(sample)
        binary_gt_mask = gt_mask > 0.5 if self.soft_targets else gt_mask > 0

        self.selected_mask = gt_mask
        self._selected_masks = pos_masks

        neg_mask_bg = np.logical_not(binary_gt_mask)
        neg_mask_border = self._get_border_mask(binary_gt_mask)
        if len(sample) <= len(self._selected_masks):
            # No more objects in the sample than selected masks: fall back to
            # the plain background region for the "other" strategy.
            neg_mask_other = neg_mask_bg
        else:
            neg_mask_other = np.logical_and(
                np.logical_not(sample.get_background_mask()),
                np.logical_not(binary_gt_mask),
            )

        self._neg_masks = {
            "bg": neg_mask_bg,
            "other": neg_mask_other,
            "border": neg_mask_border,
            "required": neg_masks,
        }

    def _sample_mask(self, sample: DSample):
        """Pick one or several root objects and combine their masks.

        Returns:
            ``(gt_mask, pos_masks, neg_masks)`` where ``gt_mask`` is the
            element-wise maximum over the chosen objects and the two lists
            hold the (possibly eroded) positive and negative segments.
        """
        root_obj_ids = sample.root_objects

        random_ids = None
        if len(root_obj_ids) > 1 and random.random() < self.merge_objects_prob:
            max_selected_objects = min(len(root_obj_ids), self.max_num_merged_objects)
            # np.random.randint(2, n + 1) raises for n < 2, which happens when
            # max_num_merged_objects < 2; there is nothing to merge then.
            if max_selected_objects >= 2:
                num_selected_objects = np.random.randint(2, max_selected_objects + 1)
                random_ids = random.sample(root_obj_ids, num_selected_objects)
        if random_ids is None:
            random_ids = [random.choice(root_obj_ids)]

        gt_mask = None
        pos_segments = []
        neg_segments = []
        for obj_id in random_ids:
            (
                obj_gt_mask,
                obj_pos_segments,
                obj_neg_segments,
            ) = self._sample_from_masks_layer(obj_id, sample)
            if gt_mask is None:
                gt_mask = obj_gt_mask
            else:
                gt_mask = np.maximum(gt_mask, obj_gt_mask)

            pos_segments.extend(obj_pos_segments)
            neg_segments.extend(obj_neg_segments)

        # The same random erosion is applied to positive and negative segments.
        pos_masks = [self._positive_erode(x) for x in pos_segments]
        neg_masks = [self._positive_erode(x) for x in neg_segments]

        return gt_mask, pos_masks, neg_masks

    def _sample_from_masks_layer(self, obj_id, sample: DSample):
        """Build target, positive and negative masks for one root object.

        Without ``use_hierarchy`` the object's own mask is used and there are
        no negative segments.  With it, a node is chosen by walking down the
        object tree (stopping at each node with probability 0.5 or when it
        has no children), and the node's parent and children shape the
        positive and negative regions.

        Returns:
            ``(gt_mask, [pos_mask], negative_segments)``.
        """
        # Mapping node id -> info with "children", "parent" and "area" entries.
        objs_tree = sample._objects

        if not self.use_hierarchy:
            node_mask = sample.get_object_mask(obj_id)
            gt_mask = (
                sample.get_soft_object_mask(obj_id) if self.soft_targets else node_mask
            )
            return gt_mask, [node_mask], []

        def _select_node(node_id):
            node_info = objs_tree[node_id]
            if not node_info["children"] or random.random() < 0.5:
                return node_id
            return _select_node(random.choice(node_info["children"]))

        selected_node = _select_node(obj_id)
        node_info = objs_tree[selected_node]
        node_mask = sample.get_object_mask(selected_node)
        gt_mask = (
            sample.get_soft_object_mask(selected_node)
            if self.soft_targets
            else node_mask
        )
        pos_mask = node_mask.copy()

        negative_segments = []
        if node_info["parent"] is not None and node_info["parent"] in objs_tree:
            # The part of the parent not covered by the node is a negative region.
            parent_mask = sample.get_object_mask(node_info["parent"])
            negative_segments.append(
                np.logical_and(parent_mask, np.logical_not(node_mask))
            )

        # Children smaller than 10% of the node are excluded from positive
        # click candidates (the target mask itself is left untouched).
        for child_id in node_info["children"]:
            if objs_tree[child_id]["area"] / node_info["area"] < 0.10:
                child_mask = sample.get_object_mask(child_id)
                pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))

        if node_info["children"]:
            # Up to three random children are "disabled": removed from both
            # the positive region and the target, and added as negatives.
            max_disabled_children = min(len(node_info["children"]), 3)
            num_disabled_children = np.random.randint(0, max_disabled_children + 1)
            disabled_children = random.sample(
                node_info["children"], num_disabled_children
            )

            for child_id in disabled_children:
                child_mask = sample.get_object_mask(child_id)
                pos_mask = np.logical_and(pos_mask, np.logical_not(child_mask))
                if self.soft_targets:
                    soft_child_mask = sample.get_soft_object_mask(child_id)
                    gt_mask = np.minimum(gt_mask, 1.0 - soft_child_mask)
                else:
                    gt_mask = np.logical_and(gt_mask, np.logical_not(child_mask))
                negative_segments.append(child_mask)

        return gt_mask, [pos_mask], negative_segments

    def sample_points(self):
        """Draw clicks for the current selection.

        Returns:
            The positive points followed by the negative points, as one list.
            ``sample_object`` must have been called first.
        """
        assert self._selected_mask is not None
        pos_points = self._multi_mask_sample_points(
            self._selected_masks,
            is_negative=[False] * len(self._selected_masks),
            with_first_click=self.first_click_center,
        )

        # The three negative strategies form one weighted mixture entry.
        neg_strategy = [
            (self._neg_masks[k], prob)
            for k, prob in zip(self.neg_strategies, self.neg_strategies_prob)
        ]
        # "required" masks are sampled with is_negative=False so each gets at
        # least one point; only the mixture may yield zero points.
        neg_masks = self._neg_masks["required"] + [neg_strategy]
        neg_points = self._multi_mask_sample_points(
            neg_masks, is_negative=[False] * len(self._neg_masks["required"]) + [True]
        )

        return pos_points + neg_points

    def _multi_mask_sample_points(
        self, selected_masks, is_negative, with_first_click=False
    ):
        """Sample up to ``max_num_points`` points from several masks.

        Args:
            selected_masks: Masks, or lists of ``(mask, prob)`` pairs; only
                the first ``max_num_points`` entries are used.
            is_negative: Per-mask flag forwarded to ``_sample_points``.
            with_first_click: Forwarded to ``_sample_points``.

        Returns:
            A list of exactly ``max_num_points`` entries when fewer real
            points were drawn, padded with ``(-1, -1, -1)``.
        """
        selected_masks = selected_masks[: self.max_num_points]

        each_obj_points = [
            self._sample_points(
                mask, is_negative=is_negative[i], with_first_click=with_first_click
            )
            for i, mask in enumerate(selected_masks)
        ]
        each_obj_points = [x for x in each_obj_points if len(x) > 0]

        points = []
        if len(each_obj_points) == 1:
            points = each_obj_points[0]
        elif len(each_obj_points) > 1:
            if self.only_one_first_click:
                each_obj_points = each_obj_points[:1]

            # Keep the leading point of every object, then top up with points
            # drawn from the weighted union of all masks.
            points = [obj_points[0] for obj_points in each_obj_points]

            aggregated_masks_with_prob = []
            for x in selected_masks:
                if (
                    isinstance(x, (list, tuple))
                    and x
                    and isinstance(x[0], (list, tuple))
                ):
                    for t, prob in x:
                        aggregated_masks_with_prob.append(
                            (t, prob / len(selected_masks))
                        )
                else:
                    aggregated_masks_with_prob.append((x, 1.0 / len(selected_masks)))

            other_points_union = self._sample_points(
                aggregated_masks_with_prob, is_negative=True
            )
            if len(other_points_union) + len(points) <= self.max_num_points:
                points.extend(other_points_union)
            else:
                points.extend(
                    random.sample(other_points_union, self.max_num_points - len(points))
                )

        if len(points) < self.max_num_points:
            points.extend([(-1, -1, -1)] * (self.max_num_points - len(points)))

        return points

    def _sample_points(self, mask, is_negative=False, with_first_click=False):
        """Draw a random number of points from one mask or a weighted mixture.

        Args:
            mask: Either a single mask, or a list/tuple of ``(mask, prob)``
                pairs whose probabilities sum to 1.
            is_negative: If true the number of points is drawn from
                0..max_num_points, otherwise from 1..max_num_points.
            with_first_click: For a single mask, take the first point from
                ``get_point_candidates`` and tag it with index 0.

        Returns:
            A list of ``[*coords, index]`` lists; may be shorter than the
            drawn count when a chosen region has no candidates.
        """
        if is_negative:
            num_points = np.random.choice(
                np.arange(self.max_num_points + 1), p=self._neg_probs
            )
        else:
            num_points = 1 + np.random.choice(
                np.arange(self.max_num_points), p=self._pos_probs
            )

        indices_probs = None
        if isinstance(mask, (list, tuple)):
            indices_probs = [x[1] for x in mask]
            indices = [(np.argwhere(x), prob) for x, prob in mask]
            if indices_probs:
                assert math.isclose(sum(indices_probs), 1.0)
        else:
            indices = np.argwhere(mask)

        points = []
        for j in range(num_points):
            first_click = with_first_click and j == 0 and indices_probs is None

            if first_click:
                point_indices = get_point_candidates(
                    mask, k=self.sfc_inner_k, full_prob=self.sfc_full_inner_prob
                )
            elif indices_probs:
                # Choose which mask of the mixture this point comes from.
                point_indices_indx = np.random.choice(
                    np.arange(len(indices)), p=indices_probs
                )
                point_indices = indices[point_indices_indx][0]
            else:
                point_indices = indices

            num_indices = len(point_indices)
            if num_indices > 0:
                point_indx = 0 if first_click else 100
                click = point_indices[np.random.randint(0, num_indices)].tolist() + [
                    point_indx
                ]
                points.append(click)

        return points

    def _positive_erode(self, mask):
        """Randomly erode ``mask`` with a 3x3 kernel.

        The original mask is returned when the random draw exceeds
        ``positive_erode_prob`` or when 10 or fewer pixels survive erosion.
        """
        if random.random() > self.positive_erode_prob:
            return mask

        kernel = np.ones((3, 3), np.uint8)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        eroded_mask = cv2.erode(
            mask.astype(np.uint8), kernel, iterations=self.positive_erode_iters
        ).astype(bool)

        if eroded_mask.sum() > 10:
            return eroded_mask
        else:
            return mask

    def _get_border_mask(self, mask):
        """Return a uint8 ring around ``mask``: its dilation minus the mask."""
        expand_r = int(np.ceil(self.expand_ratio * np.sqrt(mask.sum())))
        kernel = np.ones((3, 3), np.uint8)
        expanded_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=expand_r)
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        expanded_mask[mask.astype(bool)] = 0
        return expanded_mask


@lru_cache(maxsize=None)
def generate_probs(max_num_points, gamma):
    """Return normalised geometric weights ``gamma**i`` for ``i < max_num_points``.

    Args:
        max_num_points: Number of weights to generate.
        gamma: Ratio between consecutive weights.

    Returns:
        A float64 array of length ``max_num_points`` that sums to 1.  The
        result is cached per argument pair, so the same array object is
        handed to every caller and should be treated as read-only.
    """
    probs = []
    last_value = 1
    for _ in range(max_num_points):
        probs.append(last_value)
        last_value *= gamma

    # Force a float dtype: with only ints in the list (max_num_points == 1 or
    # an integer gamma) NumPy would build an integer array and the in-place
    # division below would fail with a casting error.
    probs = np.array(probs, dtype=np.float64)
    probs /= probs.sum()

    return probs


def get_point_candidates(obj_mask, k=1.7, full_prob=0.0):
    """Return candidate click coordinates inside ``obj_mask``.

    Args:
        obj_mask: 2-D object mask (non-zero means object).
        k: If positive, candidates are the pixels whose distance-transform
            value exceeds ``max / k``.  Otherwise a single pixel is drawn
            with probability proportional to its distance-transform value.
        full_prob: Probability of skipping the distance transform and
            returning every object pixel as a candidate.

    Returns:
        An ``(N, 2)`` integer array of coordinates in ``np.argwhere`` order;
        ``N`` may be 0.
    """
    if full_prob > 0 and random.random() < full_prob:
        # Return coordinates like every other branch; callers index rows of
        # the result as click positions, so the raw mask would be wrong here.
        return np.argwhere(obj_mask)

    # Pad with a one-pixel zero frame before the distance transform, then
    # crop it off again (presumably so object pixels on the image edge are
    # treated as adjacent to background -- confirm against OpenCV docs).
    padded_mask = np.pad(obj_mask, ((1, 1), (1, 1)), "constant")

    dt = cv2.distanceTransform(padded_mask.astype(np.uint8), cv2.DIST_L2, 0)[1:-1, 1:-1]
    if k > 0:
        inner_mask = dt > dt.max() / k
        return np.argwhere(inner_mask)

    prob_map = dt.flatten()
    total = prob_map.sum()
    if total <= 0:
        # All-zero distance map: np.random.choice would reject the all-zero
        # probabilities, so report "no candidates" instead.
        return np.empty((0, 2), dtype=np.intp)

    prob_map /= max(total, 1e-6)
    click_indx = np.random.choice(len(prob_map), p=prob_map)
    click_coords = np.unravel_index(click_indx, dt.shape)
    return np.array([click_coords])