Spaces:
Runtime error
Runtime error
import json | |
import random | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
from isegm.data.base import ISDataset | |
from isegm.data.sample import DSample | |
class CocoDataset(ISDataset):
    """Dataset over COCO panoptic annotations (an ``ISDataset`` subclass).

    Directory layout read by this class::

        <dataset_path>/annotations/panoptic_<split>.json  # annotation index
        <dataset_path>/annotations/panoptic_<split>/      # panoptic label PNGs
        <dataset_path>/<split>/                            # images (.jpg)

    Each sample is one annotated image. Non-crowd "thing" segments are always
    exposed as objects; "stuff" segments are exposed only with probability
    ``stuff_prob``.
    """

    def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
        """
        Args:
            dataset_path: Root directory of the dataset (see class docstring).
            split: Split name used to build the annotation and image paths.
            stuff_prob: Probability that a sample also exposes its stuff
                segments as objects; values <= 0 never expose stuff.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
        """
        super().__init__(**kwargs)
        self.split = split
        self.dataset_path = Path(dataset_path)
        self.stuff_prob = stuff_prob
        self.load_samples()

    def load_samples(self):
        """Read the panoptic annotation index and cache category id lookups.

        Populates ``dataset_samples`` (the ``"annotations"`` list of the JSON),
        the resolved label/image directories, and the thing/stuff category id
        lists and sets used by ``get_sample``.
        """
        annotation_path = (
            self.dataset_path / "annotations" / f"panoptic_{self.split}.json"
        )
        self.labels_path = self.dataset_path / "annotations" / f"panoptic_{self.split}"
        self.images_path = self.dataset_path / self.split

        # Explicit encoding: the default text encoding is locale-dependent.
        with open(annotation_path, "r", encoding="utf-8") as f:
            annotation = json.load(f)

        self.dataset_samples = annotation["annotations"]
        self._categories = annotation["categories"]
        self._stuff_labels = [x["id"] for x in self._categories if x["isthing"] == 0]
        self._things_labels = [x["id"] for x in self._categories if x["isthing"] == 1]
        # Sets give O(1) membership tests in get_sample's per-segment loop.
        self._things_labels_set = set(self._things_labels)
        self._stuff_labels_set = set(self._stuff_labels)

    def get_sample(self, index) -> DSample:
        """Load sample ``index`` and build its instance map.

        Returns:
            A ``DSample`` built from the RGB image, an int32 instance map whose
            pixels hold the panoptic segment id of an exposed segment (0
            elsewhere, including crowd regions), and the list of exposed
            segment ids.

        Raises:
            FileNotFoundError: If the image or its panoptic label PNG cannot
                be read.
        """
        dataset_sample = self.dataset_samples[index]
        file_name = dataset_sample["file_name"]
        image_path = self.images_path / self.get_image_name(file_name)
        label_path = self.labels_path / file_name

        image = cv2.imread(str(image_path))
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        label_bgr = cv2.imread(str(label_path), cv2.IMREAD_UNCHANGED)
        if label_bgr is None:
            raise FileNotFoundError(f"Cannot read panoptic label: {label_path}")
        label_bgr = label_bgr.astype(np.int32)
        # cv2 returns channels in BGR order (hence the BGR2RGB conversion
        # above), so this computes R + 256*G + 256**2*B: the panoptic segment
        # id stored in each pixel.
        label = (
            256 * 256 * label_bgr[:, :, 0]
            + 256 * label_bgr[:, :, 1]
            + label_bgr[:, :, 2]
        )

        # Decide up front whether stuff segments are exposed so excluded stuff
        # is never painted. random.random() is drawn only when stuff_prob > 0.
        use_stuff = self.stuff_prob > 0 and random.random() < self.stuff_prob

        instance_map = np.zeros_like(label)
        things_ids = []
        stuff_ids = []
        for segment in dataset_sample["segments_info"]:
            obj_id = segment["id"]
            if segment["category_id"] in self._things_labels_set:
                if segment["iscrowd"] == 1:
                    # Crowd segments are skipped and stay 0 in the map.
                    continue
                things_ids.append(obj_id)
            else:
                stuff_ids.append(obj_id)
                if not use_stuff:
                    continue
            instance_map[label == obj_id] = obj_id

        instances_ids = things_ids + stuff_ids if use_stuff else things_ids
        return DSample(image, instance_map, objects_ids=instances_ids)

    @classmethod
    def get_image_name(cls, panoptic_name):
        """Return the image file name for a panoptic label file name.

        Every ``".png"`` occurrence in ``panoptic_name`` is replaced by ``".jpg"``.
        """
        return panoptic_name.replace(".png", ".jpg")