File size: 5,972 Bytes
2cdd41c
1615d09
 
2cdd41c
 
1615d09
 
 
2cdd41c
 
1615d09
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
2cdd41c
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
1615d09
2cdd41c
 
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
2cdd41c
 
1615d09
 
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
 
 
1615d09
2cdd41c
 
 
 
 
 
 
1615d09
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
1615d09
2cdd41c
1615d09
 
 
2cdd41c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from copy import deepcopy

import numpy as np
from albumentations import ReplayCompose

from isegm.data.transforms import remove_image_only_transforms
from isegm.utils.misc import get_labels_with_sizes


class DSample:
    """A single sample: an image plus integer-encoded object masks.

    Masks are kept as a 3-D integer array whose last axis indexes "layers".
    Each object is identified by a ``(layer_index, mask_id)`` mapping into
    that array. Objects live in ``self._objects`` (object id -> info dict with
    ``"parent"``, ``"mapping"``, ``"children"`` and, once computed, ``"area"``)
    and form a parent/children tree. Optional ignored regions are also
    ``(layer_index, mask_id)`` pairs; their pixels are marked ``-1`` in the
    per-object masks returned by ``get_object_mask``.
    """

    def __init__(
        self,
        image,
        encoded_masks,
        objects=None,
        objects_ids=None,
        ignore_ids=None,
        sample_id=None,
    ):
        """Build a sample.

        Args:
            image: The sample image; stored as-is.
            encoded_masks: 2-D or 3-D integer label array. A 2-D array is
                treated as a single layer (a trailing axis is added).
            objects: Pre-built object tree (see class docstring); deep-copied.
                Only used when ``objects_ids`` is None.
            objects_ids: Either bare mask ids (single-layer masks only) or
                ``(layer_index, mask_id)`` tuples. Each entry becomes a root
                object whose id is its position in this sequence.
            ignore_ids: Bare mask ids or ``(layer_index, mask_id)`` tuples to
                treat as ignored regions. Only honoured together with
                ``objects_ids``.
            sample_id: Optional identifier, stored as-is.
        """
        self.image = image
        self.sample_id = sample_id

        if len(encoded_masks.shape) == 2:
            encoded_masks = encoded_masks[:, :, np.newaxis]
        self._encoded_masks = encoded_masks
        self._ignored_regions = []

        if objects_ids is not None:
            if not objects_ids or not isinstance(objects_ids[0], tuple):
                # Bare ids carry no layer index, so they are only meaningful
                # for single-layer masks.
                assert encoded_masks.shape[2] == 1
                objects_ids = [(0, obj_id) for obj_id in objects_ids]

            self._objects = {
                indx: {"parent": None, "mapping": obj_mapping, "children": []}
                for indx, obj_mapping in enumerate(objects_ids)
            }

            if ignore_ids:
                if isinstance(ignore_ids[0], tuple):
                    # Copy so later mutation of the caller's list cannot leak in.
                    self._ignored_regions = list(ignore_ids)
                else:
                    self._ignored_regions = [(0, region_id) for region_id in ignore_ids]
        else:
            self._objects = deepcopy(objects)

        self._augmented = False
        self._soft_mask_aug = None
        # Snapshot restored by reset_augmentation() and read by
        # get_soft_object_mask(); it must never be modified in place.
        self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)

    def augment(self, augmentator):
        """Apply ``augmentator`` to the image and masks.

        A previous ``augment()`` is undone first, so repeated calls start from
        the original data. ``augmentator`` is called as
        ``augmentator(image=..., mask=...)`` and must return a dict with
        ``"image"`` and ``"mask"`` entries. If it also returns a truthy
        ``"replay"`` entry, that replay (with image-only transforms removed via
        ``remove_image_only_transforms``) is kept for ``get_soft_object_mask``;
        this is only allowed when there are no ignored regions. Afterwards
        object areas are recomputed and objects with area < 1 are dropped.
        """
        self.reset_augmentation()
        aug_output = augmentator(image=self.image, mask=self._encoded_masks)
        self.image = aug_output["image"]
        self._encoded_masks = aug_output["mask"]

        aug_replay = aug_output.get("replay", None)
        if aug_replay:
            assert len(self._ignored_regions) == 0
            mask_replay = remove_image_only_transforms(aug_replay)
            self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)

        self._compute_objects_areas()
        self.remove_small_objects(min_area=1)

        self._augmented = True

    def reset_augmentation(self):
        """Restore the image, masks and objects saved at construction time.

        Does nothing if the sample is not currently augmented.
        """
        if not self._augmented:
            return
        orig_image, orig_masks, orig_objects = self._original_data
        self.image = orig_image
        self._encoded_masks = orig_masks
        self._objects = deepcopy(orig_objects)
        self._augmented = False
        self._soft_mask_aug = None

    def remove_small_objects(self, min_area):
        """Remove every object whose pixel area is below ``min_area``.

        Areas are computed on demand when the objects do not carry them yet
        (only the first object is inspected to decide). Children of a removed
        object are re-attached to that object's parent.
        """
        if self._objects and "area" not in next(iter(self._objects.values())):
            self._compute_objects_areas()

        # Iterate over a snapshot: _remove_object mutates self._objects.
        for obj_id, obj_info in list(self._objects.items()):
            if obj_info["area"] < min_area:
                self._remove_object(obj_id)

    def get_object_mask(self, obj_id):
        """Return an int32 mask for ``obj_id``.

        The mask is 1 on the object's pixels and 0 elsewhere. Pixels of any
        ignored region are set to -1, overriding the object.
        """
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
        for ignore_layer, ignore_id in self._ignored_regions:
            ignore_mask = self._encoded_masks[:, :, ignore_layer] == ignore_id
            obj_mask[ignore_mask] = -1

        return obj_mask

    def get_soft_object_mask(self, obj_id):
        """Return a float mask of ``obj_id`` produced by replaying the mask
        transforms on the original (un-augmented) binary object mask.

        Requires a prior ``augment()`` whose output contained a replay. The
        result is clipped to [0, 1].
        """
        assert self._soft_mask_aug is not None
        original_encoded_masks = self._original_data[1]
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
            np.float32
        )
        obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
            "image"
        ]
        return np.clip(obj_mask, 0, 1)

    def get_background_mask(self):
        """Return a boolean mask that is True where the maximum label over all
        layers is 0 (i.e. unlabelled pixels, assuming non-negative labels)."""
        return np.max(self._encoded_masks, axis=2) == 0

    @property
    def objects_ids(self):
        """Ids of the current objects, in dict insertion order."""
        return list(self._objects.keys())

    @property
    def gt_mask(self):
        """Mask of the only object; asserts there is exactly one object."""
        assert len(self._objects) == 1
        return self.get_object_mask(self.objects_ids[0])

    @property
    def root_objects(self):
        """Ids of objects that have no parent."""
        return [
            obj_id
            for obj_id, obj_info in self._objects.items()
            if obj_info["parent"] is None
        ]

    def _compute_objects_areas(self):
        """Store each object's pixel area under ``"area"``.

        Labels present in the masks that belong to neither an object nor an
        ignored region are erased (set to 0). Objects whose label is not
        present get area 0.

        The erasure is done on a private copy of the mask array, made on the
        first erase. Otherwise it would write through numpy views into the
        construction-time snapshot in ``self._original_data`` and into the
        caller's input array.
        """
        inverse_index = {
            node["mapping"]: node_id for node_id, node in self._objects.items()
        }
        ignored_regions_keys = set(self._ignored_regions)
        masks = self._encoded_masks
        masks_copied = False

        for layer_indx in range(masks.shape[2]):
            # NOTE(review): get_labels_with_sizes is assumed to return the
            # distinct labels of the layer and their pixel counts — confirm
            # whether label 0 is included (erasing it would be a no-op anyway).
            objects_ids, objects_areas = get_labels_with_sizes(masks[:, :, layer_indx])
            for obj_id, obj_area in zip(objects_ids, objects_areas):
                inv_key = (layer_indx, obj_id)
                if inv_key in ignored_regions_keys:
                    continue

                node_id = inverse_index.pop(inv_key, None)
                if node_id is not None:
                    self._objects[node_id]["area"] = obj_area
                    continue

                # Unknown label: erase it, but never in a shared array.
                if not masks_copied:
                    masks = masks.copy()
                    masks_copied = True
                layer = masks[:, :, layer_indx]  # view into our private copy
                layer[layer == obj_id] = 0

        self._encoded_masks = masks

        for obj_id in inverse_index.values():
            self._objects[obj_id]["area"] = 0

    def _remove_object(self, obj_id):
        """Delete ``obj_id`` and splice its children onto its parent."""
        obj_info = self._objects[obj_id]
        obj_parent = obj_info["parent"]
        for child_id in obj_info["children"]:
            self._objects[child_id]["parent"] = obj_parent

        if obj_parent is not None:
            parent_children = self._objects[obj_parent]["children"]
            parent_children = [x for x in parent_children if x != obj_id]
            self._objects[obj_parent]["children"] = (
                parent_children + obj_info["children"]
            )

        del self._objects[obj_id]

    def __len__(self):
        """Number of current objects."""
        return len(self._objects)