import pickle as pkl
from pathlib import Path

import cv2
import numpy as np
from scipy.io import loadmat

from isegm.data.base import ISDataset
from isegm.data.sample import DSample
from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes


def _read_sample_names(split_file):
    """Return the stripped, non-empty lines of *split_file* as a list.

    Blank lines (e.g. a trailing newline at the end of the file) are skipped so
    they do not turn into empty sample names that later resolve to ``.jpg`` /
    ``.mat`` paths.
    """
    with open(split_file, "r") as f:
        stripped_lines = (line.strip() for line in f)
        return [name for name in stripped_lines if name]


def _read_rgb_image(image_path):
    """Read the image at *image_path* and return it converted from BGR to RGB.

    Raises:
        OSError: if OpenCV could not read the file. ``cv2.imread`` signals
            failure by returning ``None`` rather than raising, which would
            otherwise surface as an opaque error inside ``cv2.cvtColor``.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        raise OSError(f"Cannot read image (missing or undecodable): {image_path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _load_instances_mask(inst_info_path):
    """Load the instance mask stored in a ``.mat`` file as an ``int32`` array.

    The mask is taken from ``["GTinst"][0][0][0]`` of the loaded file.
    NOTE(review): presumably this is the per-pixel instance-id map of the SBD
    annotation struct -- confirm against the dataset format.
    """
    return loadmat(str(inst_info_path))["GTinst"][0][0][0].astype(np.int32)


class SBDDataset(ISDataset):
    """Dataset yielding one sample per image with all of its instances.

    Instances whose mask fills less than ``buggy_mask_thresh`` of their bounding
    box are treated as buggy and erased from the mask.
    """

    def __init__(self, dataset_path, split="train", buggy_mask_thresh=0.08, **kwargs):
        """Set up paths and read the list of sample names for *split*.

        Args:
            dataset_path: dataset root containing ``img/``, ``inst/`` and
                ``<split>.txt``.
            split: ``"train"`` or ``"val"``.
            buggy_mask_thresh: minimal allowed mask-area / bbox-area ratio;
                filtering is disabled unless this is greater than 0.
            **kwargs: forwarded to ``ISDataset.__init__``.
        """
        super().__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"
        # Per-index cache of object ids already identified as buggy.
        self._buggy_objects = dict()
        self._buggy_mask_thresh = buggy_mask_thresh

        self.dataset_samples = _read_sample_names(self.dataset_path / f"{split}.txt")

    def get_sample(self, index) -> DSample:
        """Load image and filtered instance mask for the sample at *index*.

        Raises:
            OSError: if the image file cannot be read.
        """
        image_name = self.dataset_samples[index]
        image = _read_rgb_image(self._images_path / f"{image_name}.jpg")
        instances_mask = _load_instances_mask(self._insts_path / f"{image_name}.mat")

        instances_mask = self.remove_buggy_masks(index, instances_mask)
        instances_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(
            image, instances_mask, objects_ids=instances_ids, sample_id=index
        )

    def remove_buggy_masks(self, index, instances_mask):
        """Zero out buggy objects in *instances_mask* and return it.

        The mask is modified in place. The ids of buggy objects are computed
        once per *index* and cached in ``self._buggy_objects``. When the
        threshold is not greater than 0 the mask is returned untouched.
        """
        if self._buggy_mask_thresh > 0.0:
            buggy_ids = self._buggy_objects.get(index)
            if buggy_ids is None:
                buggy_ids = self._find_buggy_objects(instances_mask)
                self._buggy_objects[index] = buggy_ids

            for obj_id in buggy_ids:
                instances_mask[instances_mask == obj_id] = 0

        return instances_mask

    def _find_buggy_objects(self, instances_mask):
        """Return ids of objects whose mask/bbox area ratio is below the threshold."""
        buggy_ids = []
        instances_ids, _ = get_labels_with_sizes(instances_mask)
        for obj_id in instances_ids:
            obj_mask = instances_mask == obj_id
            mask_area = obj_mask.sum()
            # NOTE(review): assumes the bbox is laid out as inclusive
            # (min, max) pairs for the two axes -- confirm with
            # get_bbox_from_mask.
            bbox = get_bbox_from_mask(obj_mask)
            bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
            if mask_area / bbox_area < self._buggy_mask_thresh:
                buggy_ids.append(obj_id)
        return buggy_ids


class SBDEvaluationDataset(ISDataset):
    """Dataset yielding one sample per (image, instance) pair with a binary mask."""

    def __init__(self, dataset_path, split="val", **kwargs):
        """Set up paths and build (or load) the (image name, instance id) list.

        Args:
            dataset_path: dataset root containing ``img/``, ``inst/`` and
                ``<split>.txt``.
            split: ``"train"`` or ``"val"``.
            **kwargs: forwarded to ``ISDataset.__init__``.
        """
        super().__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"

        # First the plain image names, then expanded to (name, instance id).
        self.dataset_samples = _read_sample_names(self.dataset_path / f"{split}.txt")
        self.dataset_samples = self.get_sbd_images_and_ids_list()

    def get_sample(self, index) -> DSample:
        """Load the image and the binary mask of one instance for *index*.

        Raises:
            OSError: if the image file cannot be read.
        """
        image_name, instance_id = self.dataset_samples[index]
        image = _read_rgb_image(self._images_path / f"{image_name}.jpg")
        instances_mask = _load_instances_mask(self._insts_path / f"{image_name}.mat")

        # Keep only the requested instance and relabel it as 1.
        instances_mask[instances_mask != instance_id] = 0
        instances_mask[instances_mask > 0] = 1

        return DSample(image, instances_mask, objects_ids=[1], sample_id=index)

    def get_sbd_images_and_ids_list(self):
        """Return the list of ``(image name, instance id)`` pairs for the split.

        The list is cached as ``<split>_images_and_ids_list.pkl`` in the dataset
        root: it is loaded from there when the file exists, otherwise it is
        built from the ``.mat`` annotations and written to that file.
        """
        pkl_path = self.dataset_path / f"{self.dataset_split}_images_and_ids_list.pkl"

        if pkl_path.exists():
            # NOTE: unpickling executes arbitrary code from the file; only use
            # dataset directories that come from a trusted source.
            with open(str(pkl_path), "rb") as fp:
                return pkl.load(fp)

        images_and_ids_list = []
        for sample in self.dataset_samples:
            instances_mask = _load_instances_mask(self._insts_path / f"{sample}.mat")
            instances_ids, _ = get_labels_with_sizes(instances_mask)
            images_and_ids_list.extend(
                (sample, instances_id) for instances_id in instances_ids
            )

        with open(str(pkl_path), "wb") as fp:
            pkl.dump(images_and_ids_list, fp)

        return images_and_ids_list