Spaces:
Runtime error
Runtime error
File size: 5,972 Bytes
2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c |
|
from copy import deepcopy
import numpy as np
from albumentations import ReplayCompose
from isegm.data.transforms import remove_image_only_transforms
from isegm.utils.misc import get_labels_with_sizes
class DSample:
    """An image together with layered instance-label masks and an object tree.

    Objects live in ``self._objects`` as a dict ``node_id -> info`` where
    ``info`` holds ``"parent"`` (node id or None), ``"mapping"`` (a
    ``(layer_index, label)`` pair into the encoded masks), ``"children"``
    (list of node ids) and, once computed, ``"area"`` (pixel count).

    ``augment`` replaces the image/masks with an augmented version;
    ``reset_augmentation`` restores the state captured at construction time.
    """

    def __init__(
        self,
        image,
        encoded_masks,
        objects=None,
        objects_ids=None,
        ignore_ids=None,
        sample_id=None,
    ):
        """Build a sample.

        Args:
            image: The image; stored by reference and handed to the augmentator.
            encoded_masks: Label array indexed as ``[row, col, layer]``. A 2-D
                array is treated as a single layer (a trailing axis is added).
            objects: Pre-built object dict (see class docstring). Used only
                when ``objects_ids`` is None; it is deep-copied.
            objects_ids: Either plain labels (allowed only for a single-layer
                mask) or ``(layer_index, label)`` pairs. Any iterable is
                accepted. Each entry becomes a root object whose node id is its
                position in this sequence.
            ignore_ids: Plain labels or ``(layer_index, label)`` pairs marking
                regions that ``get_object_mask`` fills with -1. Only read
                together with ``objects_ids``.
            sample_id: Identifier stored on the instance as-is.
        """
        self.image = image
        self.sample_id = sample_id
        if len(encoded_masks.shape) == 2:
            encoded_masks = encoded_masks[:, :, np.newaxis]
        self._encoded_masks = encoded_masks
        self._ignored_regions = []

        if objects_ids is not None:
            # Materialise as a list: numpy arrays cannot be truth-tested with
            # ``not``, and this also avoids aliasing the caller's container.
            objects_ids = list(objects_ids)
            if not objects_ids or not isinstance(objects_ids[0], tuple):
                # Plain labels are only unambiguous for a single-layer mask.
                assert encoded_masks.shape[2] == 1
                objects_ids = [(0, obj_id) for obj_id in objects_ids]

            self._objects = {
                indx: {"parent": None, "mapping": obj_mapping, "children": []}
                for indx, obj_mapping in enumerate(objects_ids)
            }

            ignore_ids = list(ignore_ids) if ignore_ids is not None else []
            if ignore_ids:
                if isinstance(ignore_ids[0], tuple):
                    self._ignored_regions = ignore_ids
                else:
                    self._ignored_regions = [
                        (0, region_id) for region_id in ignore_ids
                    ]
        else:
            self._objects = deepcopy(objects)

        self._augmented = False
        self._soft_mask_aug = None
        # Snapshot used by reset_augmentation(). Image and masks are kept by
        # reference, so nothing in this class may mutate them in place.
        self._original_data = self.image, self._encoded_masks, deepcopy(self._objects)

    def augment(self, augmentator):
        """Apply ``augmentator`` to the original image and masks.

        Any previous augmentation is undone first. ``augmentator`` is called as
        ``augmentator(image=..., mask=...)`` and must return a dict with
        ``"image"`` and ``"mask"`` and, optionally, ``"replay"`` (presumably an
        albumentations ``ReplayCompose`` replay record — confirm with callers).
        When a replay is present, it is passed through
        ``remove_image_only_transforms`` and restored into a callable that
        ``get_soft_object_mask`` re-applies. Afterwards object areas are
        recomputed and objects with area below 1 pixel are removed.
        """
        self.reset_augmentation()
        aug_output = augmentator(image=self.image, mask=self._encoded_masks)
        self.image = aug_output["image"]
        self._encoded_masks = aug_output["mask"]

        aug_replay = aug_output.get("replay", None)
        if aug_replay:
            # Soft masks are not supported together with ignored regions.
            assert len(self._ignored_regions) == 0
            mask_replay = remove_image_only_transforms(aug_replay)
            self._soft_mask_aug = ReplayCompose._restore_for_replay(mask_replay)

        self._compute_objects_areas()
        self.remove_small_objects(min_area=1)
        self._augmented = True

    def reset_augmentation(self):
        """Restore image, masks and objects to their construction-time state.

        Does nothing if the sample is not currently augmented.
        """
        if not self._augmented:
            return
        orig_image, orig_masks, orig_objects = self._original_data
        self.image = orig_image
        self._encoded_masks = orig_masks
        # Deep-copy again so later edits of the tree cannot reach the snapshot.
        self._objects = deepcopy(orig_objects)
        self._augmented = False
        self._soft_mask_aug = None

    def remove_small_objects(self, min_area):
        """Remove every object whose ``"area"`` is below ``min_area``.

        Areas are computed first if the objects do not carry them yet.
        Children of a removed object are re-attached to its parent.
        """
        if self._objects and "area" not in next(iter(self._objects.values())):
            self._compute_objects_areas()

        # Iterate over a snapshot because _remove_object mutates the dict.
        for obj_id, obj_info in list(self._objects.items()):
            if obj_info["area"] < min_area:
                self._remove_object(obj_id)

    def get_object_mask(self, obj_id):
        """Return an int32 mask for ``obj_id``.

        Pixels of the object are 1 and all others 0. Pixels in any ignored
        region are set to -1, overriding the object value.
        """
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (self._encoded_masks[:, :, layer_indx] == mask_id).astype(np.int32)
        for ignore_layer, ignore_id in self._ignored_regions:
            ignore_mask = self._encoded_masks[:, :, ignore_layer] == ignore_id
            obj_mask[ignore_mask] = -1
        return obj_mask

    def get_soft_object_mask(self, obj_id):
        """Return a float mask of ``obj_id`` with the recorded transforms replayed.

        The object's binary mask is taken from the original (pre-augmentation)
        masks as float32. It is then fed as ``image=`` to the replay callable
        stored by ``augment``, and the result is clipped to [0, 1]. Requires
        ``augment`` to have run with an augmentator that returned a replay.
        """
        assert self._soft_mask_aug is not None
        original_encoded_masks = self._original_data[1]
        layer_indx, mask_id = self._objects[obj_id]["mapping"]
        obj_mask = (original_encoded_masks[:, :, layer_indx] == mask_id).astype(
            np.float32
        )
        obj_mask = self._soft_mask_aug(image=obj_mask, mask=original_encoded_masks)[
            "image"
        ]
        return np.clip(obj_mask, 0, 1)

    def get_background_mask(self):
        """Return a boolean mask that is True where every layer has label 0."""
        return np.max(self._encoded_masks, axis=2) == 0

    @property
    def objects_ids(self):
        """Node ids of all current objects, in dict order."""
        return list(self._objects.keys())

    @property
    def gt_mask(self):
        """Mask of the single object; asserts that exactly one object exists."""
        assert len(self._objects) == 1
        return self.get_object_mask(self.objects_ids[0])

    @property
    def root_objects(self):
        """Node ids of objects that have no parent."""
        return [
            obj_id
            for obj_id, obj_info in self._objects.items()
            if obj_info["parent"] is None
        ]

    def _compute_objects_areas(self):
        """Store each object's pixel count in ``info["area"]``.

        Labels found in the masks that are neither mapped to an object nor an
        ignored region are erased (set to 0). Objects whose label no longer
        occurs get area 0.
        """
        inverse_index = {
            node["mapping"]: node_id for node_id, node in self._objects.items()
        }
        ignored_regions_keys = set(self._ignored_regions)
        masks_copied = False

        for layer_indx in range(self._encoded_masks.shape[2]):
            # assumes get_labels_with_sizes returns (labels, sizes) for the
            # non-background labels of a 2-D array — confirm in isegm.utils.misc
            objects_ids, objects_areas = get_labels_with_sizes(
                self._encoded_masks[:, :, layer_indx]
            )
            for obj_id, obj_area in zip(objects_ids, objects_areas):
                inv_key = (layer_indx, obj_id)
                if inv_key in ignored_regions_keys:
                    continue

                node_id = inverse_index.pop(inv_key, None)
                if node_id is not None:
                    self._objects[node_id]["area"] = obj_area
                    continue

                # Unmapped label: erase it. Copy the masks once before the first
                # write so that neither the caller's array nor the snapshot in
                # _original_data is modified (and read-only inputs still work).
                if not masks_copied:
                    self._encoded_masks = self._encoded_masks.copy()
                    masks_copied = True
                layer = self._encoded_masks[:, :, layer_indx]  # view into the copy
                layer[layer == obj_id] = 0

        # Objects whose label was not found anywhere have vanished.
        for node_id in inverse_index.values():
            self._objects[node_id]["area"] = 0

    def _remove_object(self, obj_id):
        """Delete ``obj_id`` and splice its children into its parent's list."""
        obj_info = self._objects[obj_id]
        obj_parent = obj_info["parent"]
        for child_id in obj_info["children"]:
            self._objects[child_id]["parent"] = obj_parent

        if obj_parent is not None:
            parent_children = [
                x for x in self._objects[obj_parent]["children"] if x != obj_id
            ]
            self._objects[obj_parent]["children"] = (
                parent_children + obj_info["children"]
            )

        del self._objects[obj_id]

    def __len__(self):
        """Number of objects currently in the sample."""
        return len(self._objects)
|