import json
import random
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class CocoDataset(ISDataset):
    """Interactive-segmentation dataset backed by COCO panoptic annotations.

    Expected layout under ``dataset_path`` (derived from the paths built in
    ``load_samples``)::

        annotations/panoptic_<split>.json   # annotation index
        annotations/panoptic_<split>/       # panoptic label PNGs
        <split>/                            # JPEG images

    Args:
        dataset_path: Root directory of the dataset.
        split: Name of the split; used to build the file and directory names
            shown above.
        stuff_prob: Probability, per ``get_sample`` call, of keeping "stuff"
            segments in addition to "thing" segments. With ``0`` stuff
            segments are never returned.
        **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
    """

    def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
        super().__init__(**kwargs)
        self.split = split
        self.dataset_path = Path(dataset_path)
        self.stuff_prob = stuff_prob

        self.load_samples()

    def load_samples(self):
        """Read the panoptic annotation JSON and cache the category id sets.

        Sets ``labels_path``, ``images_path``, ``dataset_samples`` and the
        ``_stuff_labels*`` / ``_things_labels*`` attributes.
        """
        annotations_dir = self.dataset_path / "annotations"
        annotation_path = annotations_dir / f"panoptic_{self.split}.json"
        self.labels_path = annotations_dir / f"panoptic_{self.split}"
        self.images_path = self.dataset_path / self.split

        with open(annotation_path, "r", encoding="utf-8") as f:
            annotation = json.load(f)

        self.dataset_samples = annotation["annotations"]

        self._categories = annotation["categories"]
        self._stuff_labels = [x["id"] for x in self._categories if x["isthing"] == 0]
        self._things_labels = [x["id"] for x in self._categories if x["isthing"] == 1]
        self._things_labels_set = set(self._things_labels)
        self._stuff_labels_set = set(self._stuff_labels)

    def get_sample(self, index) -> DSample:
        """Load the image and instance map for the annotation at ``index``.

        Crowd "thing" segments are always dropped. "Stuff" segments are kept
        only when a random draw falls below ``stuff_prob``. Pixels that do not
        belong to a kept segment are 0 in the instance map; kept pixels hold
        their segment id.

        Args:
            index: Position in ``self.dataset_samples``.

        Returns:
            A ``DSample`` built from the RGB image, the instance map and the
            list of kept segment ids (passed as ``objects_ids``).

        Raises:
            OSError: If the image or the panoptic label file cannot be read.
        """
        dataset_sample = self.dataset_samples[index]

        image_path = self.images_path / self.get_image_name(dataset_sample["file_name"])
        label_path = self.labels_path / dataset_sample["file_name"]

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly and report the offending path.
        image = cv2.imread(str(image_path))
        if image is None:
            raise OSError(f"Failed to read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        raw_label = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED)
        if raw_label is None:
            raise OSError(f"Failed to read panoptic label: {label_path}")

        # Collapse the three colour channels into one integer id per pixel.
        # OpenCV yields channels in BGR order, so this is
        # R + 256 * G + 256**2 * B. int32 avoids overflowing the uint8 input.
        label = raw_label.astype(np.int32)
        label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2]

        things_ids = []
        stuff_ids = []
        for segment in dataset_sample["segments_info"]:
            obj_id = segment["id"]
            if segment["category_id"] in self._things_labels_set:
                if segment["iscrowd"] == 1:
                    continue
                things_ids.append(obj_id)
            else:
                stuff_ids.append(obj_id)

        # Choose the kept ids before painting, so stuff masks are never
        # written only to be erased again.
        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            instances_ids = things_ids + stuff_ids
        else:
            instances_ids = things_ids

        instance_map = np.zeros_like(label)
        for obj_id in instances_ids:
            instance_map[label == obj_id] = obj_id

        return DSample(image, instance_map, objects_ids=instances_ids)

    @classmethod
    def get_image_name(cls, panoptic_name):
        """Return the image file name for a panoptic label file name.

        Replaces every ``".png"`` occurrence in ``panoptic_name`` with
        ``".jpg"``.
        """
        return panoptic_name.replace(".png", ".jpg")