import os
import pickle as pkl
import random
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


def _imread_checked(path: str) -> np.ndarray:
    """Read an image with OpenCV, failing loudly if it cannot be loaded.

    ``cv2.imread`` signals a missing or undecodable file by returning
    ``None`` instead of raising, which otherwise surfaces later as an
    opaque assertion inside ``cv2.cvtColor``.

    Args:
        path: File system path of the image to read.

    Returns:
        The decoded image as returned by ``cv2.imread`` (BGR channel order).

    Raises:
        RuntimeError: If OpenCV could not read the file.
    """
    image = cv2.imread(path)
    if image is None:
        raise RuntimeError(f"Can't read image at {path}")
    return image


class OpenImagesDataset(ISDataset):
    """Interactive-segmentation dataset backed by OpenImages masks on disk.

    Files are looked up under ``<dataset_path>/<split>/``:

    * ``images/<image_id>.jpg`` -- the input images,
    * ``masks/<mask path>`` -- the segmentation masks,
    * ``<split>-annotations-object-segmentation_clean.pkl`` -- a pickled
      dict with the keys ``"image_id_to_masks"`` and ``"dataset_samples"``.
    """

    def __init__(self, dataset_path, split="train", **kwargs):
        """Locate the split directory and load its pickled annotations.

        Args:
            dataset_path: Root directory of the dataset.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.

        Raises:
            AssertionError: If ``split`` is not a known split name.
            RuntimeError: If the annotation pickle does not exist.
        """
        super().__init__(**kwargs)
        assert split in {"train", "val", "test"}, f"Unknown split: {split!r}"

        self.dataset_path = Path(dataset_path)
        self._split_path = self.dataset_path / split
        self._images_path = self._split_path / "images"
        self._masks_path = self._split_path / "masks"
        self.dataset_split = split

        clean_anno_path = (
            self._split_path / f"{split}-annotations-object-segmentation_clean.pkl"
        )
        if not os.path.exists(clean_anno_path):
            raise RuntimeError(f"Can't find annotations at {clean_anno_path}")

        # NOTE: pickle can execute arbitrary code while loading; only point
        # this dataset at annotation files from a trusted source.
        with clean_anno_path.open("rb") as f:
            annotations = pkl.load(f)

        # Indexed by image id; each entry is a sequence of mask paths
        # relative to the masks directory (see get_sample).
        self.image_id_to_masks = annotations["image_id_to_masks"]
        # Indexed by sample position; each entry is an image id.
        self.dataset_samples = annotations["dataset_samples"]

    def get_sample(self, index) -> DSample:
        """Load the image at ``index`` together with one randomly chosen mask.

        The mask is binarized (every non-zero pixel becomes 1). If image and
        mask differ in size, whichever is larger along an axis is resized so
        both end up at the per-axis minimum width and height.

        Args:
            index: Position in ``self.dataset_samples``.

        Returns:
            A ``DSample`` holding the RGB image and the ``int32`` binary
            mask. ``objects_ids`` is ``[1]`` when the mask has at least one
            foreground pixel and ``[]`` otherwise; ``sample_id`` is ``index``.

        Raises:
            RuntimeError: If the image or the selected mask cannot be read.
        """
        image_id = self.dataset_samples[index]

        image_path = str(self._images_path / f"{image_id}.jpg")
        image = cv2.cvtColor(_imread_checked(image_path), cv2.COLOR_BGR2RGB)

        # An image can have several masks; pick one at random per call.
        mask_paths = self.image_id_to_masks[image_id]
        mask_path = str(self._masks_path / random.choice(mask_paths))
        instances_mask = cv2.cvtColor(
            _imread_checked(mask_path), cv2.COLOR_BGR2GRAY
        )
        # Collapse every non-zero gray level into a single foreground label.
        instances_mask[instances_mask > 0] = 1
        instances_mask = instances_mask.astype(np.int32)

        # Bring image and mask to a common size: the minimum along each axis.
        min_width = min(image.shape[1], instances_mask.shape[1])
        min_height = min(image.shape[0], instances_mask.shape[0])
        target_shape = (min_height, min_width)

        if image.shape[:2] != target_shape:
            image = cv2.resize(
                image, (min_width, min_height), interpolation=cv2.INTER_LINEAR
            )
        if instances_mask.shape[:2] != target_shape:
            # Nearest-neighbour keeps the mask strictly binary.
            instances_mask = cv2.resize(
                instances_mask,
                (min_width, min_height),
                interpolation=cv2.INTER_NEAREST,
            )

        object_ids = [1] if instances_mask.sum() > 0 else []
        return DSample(
            image, instances_mask, objects_ids=object_ids, sample_id=index
        )