# Spaces: Runtime error  (appears to be a hosting-page status banner captured with this source; not part of the module)
import pickle as pkl | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
from scipy.io import loadmat | |
from isegm.data.base import ISDataset | |
from isegm.data.sample import DSample | |
from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes | |
class SBDDataset(ISDataset):
    """SBD dataset that yields one sample per image with all its instances.

    Expected layout under ``dataset_path``:

    * ``<split>.txt`` -- one sample name per line;
    * ``img/<name>.jpg`` -- the image;
    * ``inst/<name>.mat`` -- a MATLAB file whose ``"GTinst"`` entry holds the
      instance label map.

    Instances whose mask covers too small a fraction of their own bounding
    box are treated as annotation errors and erased from the label map (see
    :meth:`remove_buggy_masks`).
    """

    def __init__(self, dataset_path, split="train", buggy_mask_thresh=0.08, **kwargs):
        """Read the list of sample names for ``split``.

        Args:
            dataset_path: Root directory of the dataset.
            split: Either ``"train"`` or ``"val"``; selects ``<split>.txt``.
            buggy_mask_thresh: Minimum allowed ratio of mask area to
                bounding-box area. Instances below it are removed. A value
                that is not greater than zero disables the filtering.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
        """
        super(SBDDataset, self).__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"
        # Cache: sample index -> list of instance ids judged buggy, so the
        # per-instance area scan runs only once per sample.
        self._buggy_objects = dict()
        self._buggy_mask_thresh = buggy_mask_thresh

        with open(self.dataset_path / f"{split}.txt", "r") as f:
            # Skip blank lines: an empty name would otherwise resolve to
            # "img/.jpg" and fail much later inside get_sample().
            stripped_lines = (line.strip() for line in f)
            self.dataset_samples = [name for name in stripped_lines if name]

    def get_sample(self, index):
        """Load the image and the filtered instance map for one sample.

        Args:
            index: Position of the sample in ``self.dataset_samples``.

        Returns:
            A ``DSample`` built from the RGB image, the instance label map
            (after buggy instances were zeroed) and the ids still present.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded by OpenCV.
        """
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / f"{image_name}.jpg")
        inst_info_path = str(self._insts_path / f"{image_name}.mat")

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending path rather than with an
            # opaque assertion inside cvtColor.
            raise FileNotFoundError(f"Image is missing or unreadable: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # NOTE(review): [0][0][0] presumably unwraps the MATLAB struct down
        # to its first field (the label map) -- confirm against the .mat files.
        instances_mask = loadmat(inst_info_path)["GTinst"][0][0][0].astype(np.int32)
        instances_mask = self.remove_buggy_masks(index, instances_mask)
        # Ids are computed after filtering so removed instances are not listed.
        instances_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(
            image, instances_mask, objects_ids=instances_ids, sample_id=index
        )

    def remove_buggy_masks(self, index, instances_mask):
        """Zero out instances whose mask is too sparse inside its bounding box.

        For every instance id the ratio ``mask_area / bbox_area`` is compared
        with ``self._buggy_mask_thresh``; ids below the threshold are set to
        0 in ``instances_mask``. The list of rejected ids is cached per
        ``index``.

        Args:
            index: Sample index, used as the cache key.
            instances_mask: Integer label map; modified in place.

        Returns:
            The same ``instances_mask`` array, with buggy instances erased.
        """
        if self._buggy_mask_thresh > 0.0:
            buggy_image_objects = self._buggy_objects.get(index, None)
            if buggy_image_objects is None:
                buggy_image_objects = []
                instances_ids, _ = get_labels_with_sizes(instances_mask)
                for obj_id in instances_ids:
                    obj_mask = instances_mask == obj_id
                    mask_area = obj_mask.sum()
                    # NOTE(review): bbox is presumably (min, max, min, max)
                    # with inclusive bounds, hence the "+ 1" -- confirm
                    # against get_bbox_from_mask.
                    bbox = get_bbox_from_mask(obj_mask)
                    bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
                    obj_area_ratio = mask_area / bbox_area
                    if obj_area_ratio < self._buggy_mask_thresh:
                        buggy_image_objects.append(obj_id)
                self._buggy_objects[index] = buggy_image_objects

            for obj_id in buggy_image_objects:
                instances_mask[instances_mask == obj_id] = 0

        return instances_mask
class SBDEvaluationDataset(ISDataset):
    """SBD dataset that yields one sample per (image, instance) pair.

    Uses the same on-disk layout as the split files it reads
    (``<split>.txt``, ``img/<name>.jpg``, ``inst/<name>.mat``). Each sample
    carries a binary mask of a single instance. The list of
    ``(image_name, instance_id)`` pairs is cached in
    ``<dataset_path>/<split>_images_and_ids_list.pkl``.
    """

    def __init__(self, dataset_path, split="val", **kwargs):
        """Read the sample names and expand them into per-instance entries.

        Args:
            dataset_path: Root directory of the dataset.
            split: Either ``"train"`` or ``"val"``; selects ``<split>.txt``.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
        """
        super(SBDEvaluationDataset, self).__init__(**kwargs)
        assert split in {"train", "val"}

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"

        with open(self.dataset_path / f"{split}.txt", "r") as f:
            # Skip blank lines: an empty name would otherwise resolve to
            # "inst/.mat" and break the per-instance expansion below.
            stripped_lines = (line.strip() for line in f)
            self.dataset_samples = [name for name in stripped_lines if name]

        # Replace the list of names with a list of (name, instance_id) pairs.
        self.dataset_samples = self.get_sbd_images_and_ids_list()

    def get_sample(self, index) -> DSample:
        """Load the image and a binary mask of one instance.

        Args:
            index: Position of the (image, instance) pair in
                ``self.dataset_samples``.

        Returns:
            A ``DSample`` with the RGB image and a mask that is 1 on the
            selected instance and 0 elsewhere; ``objects_ids`` is ``[1]``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded by OpenCV.
        """
        image_name, instance_id = self.dataset_samples[index]
        image_path = str(self._images_path / f"{image_name}.jpg")
        inst_info_path = str(self._insts_path / f"{image_name}.mat")

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the offending path rather than with an
            # opaque assertion inside cvtColor.
            raise FileNotFoundError(f"Image is missing or unreadable: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # NOTE(review): [0][0][0] presumably unwraps the MATLAB struct down
        # to its first field (the label map) -- confirm against the .mat files.
        instances_mask = loadmat(inst_info_path)["GTinst"][0][0][0].astype(np.int32)
        # Keep only the requested instance, then binarize it.
        instances_mask[instances_mask != instance_id] = 0
        instances_mask[instances_mask > 0] = 1

        return DSample(image, instances_mask, objects_ids=[1], sample_id=index)

    def get_sbd_images_and_ids_list(self):
        """Return every (sample_name, instance_id) pair of the split.

        The list is loaded from the pickle cache when it exists. Otherwise it
        is built by reading every instance file listed in
        ``self.dataset_samples`` and is then written to the cache.

        Returns:
            A list of ``(sample_name, instance_id)`` tuples.
        """
        pkl_path = self.dataset_path / f"{self.dataset_split}_images_and_ids_list.pkl"

        if pkl_path.exists():
            # NOTE(security): unpickling executes arbitrary code from the
            # file; only use dataset directories from a trusted source.
            with open(str(pkl_path), "rb") as fp:
                images_and_ids_list = pkl.load(fp)
        else:
            images_and_ids_list = []

            for sample in self.dataset_samples:
                inst_info_path = str(self._insts_path / f"{sample}.mat")
                instances_mask = loadmat(inst_info_path)["GTinst"][0][0][0].astype(
                    np.int32
                )
                instances_ids, _ = get_labels_with_sizes(instances_mask)
                images_and_ids_list.extend(
                    (sample, instance_id) for instance_id in instances_ids
                )

            with open(str(pkl_path), "wb") as fp:
                pkl.dump(images_and_ids_list, fp)

        return images_and_ids_list