Spaces:
Runtime error
Runtime error
import os | |
import pickle as pkl | |
import random | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
from isegm.data.base import ISDataset | |
from isegm.data.sample import DSample | |
from isegm.utils.misc import get_labels_with_sizes | |
class ADE20kDataset(ISDataset):
    """Interactive-segmentation dataset over ADE20k instance annotations.

    Annotations are loaded once, at construction time, from
    ``<dataset_path>/<split>-annotations-object-segmentation.pkl``. The keys of
    that mapping are used as image ids. Every call to :meth:`get_sample` reads
    the image and one randomly chosen mask layer from disk.
    """

    def __init__(self, dataset_path, split="train", stuff_prob=0.0, **kwargs):
        """Load the annotation index for the requested split.

        Args:
            dataset_path: Root directory of the dataset; must contain the
                pickled annotation file for ``split``.
            split: Either ``"train"`` or ``"val"``.
            stuff_prob: Probability of keeping "stuff" instances in a sample.
                With a value ``<= 0`` they are always removed.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.

        Raises:
            ValueError: If ``split`` is not ``"train"`` or ``"val"``.
            RuntimeError: If the annotation file does not exist.
        """
        super().__init__(**kwargs)
        # Explicit check instead of ``assert`` so that it survives ``python -O``.
        if split not in {"train", "val"}:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        # Not read anywhere in this class; kept as a public attribute.
        self.dataset_split_folder = "training" if split == "train" else "validation"
        self.stuff_prob = stuff_prob

        anno_path = self.dataset_path / f"{split}-annotations-object-segmentation.pkl"
        if not anno_path.exists():
            raise RuntimeError(f"Can't find annotations at {anno_path}")
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # ``dataset_path`` at annotation files from a trusted source.
        with anno_path.open("rb") as f:
            annotations = pkl.load(f)

        self.annotations = annotations
        self.dataset_samples = list(annotations.keys())

    @staticmethod
    def _imread(path, flags=None):
        """Read an image with OpenCV, raising instead of returning ``None``.

        ``cv2.imread`` signals a missing or undecodable file by returning
        ``None``; turning that into an exception here gives an error message
        that names the offending path.
        """
        path = str(path)
        image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if image is None:
            raise RuntimeError(f"Can't read image at {path}")
        return image

    def get_sample(self, index) -> DSample:
        """Build the sample at position ``index``.

        One mask layer is picked at random from the image's annotations. Unless
        the ``stuff_prob`` draw decides to keep them, instances listed in the
        layer's ``"stuff_instances"`` are zeroed out of the mask.

        Args:
            index: Position in ``self.dataset_samples``.

        Returns:
            A ``DSample`` holding the RGB image, the ``int32`` instance mask and
            the ids of the objects remaining in that mask.

        Raises:
            RuntimeError: If the image or the mask file cannot be read.
        """
        image_id = self.dataset_samples[index]
        sample_annos = self.annotations[image_id]
        sample_dir = self.dataset_path / sample_annos["folder"]

        image = self._imread(sample_dir / f"{image_id}.jpg")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # select random mask for an image
        layer = random.choice(sample_annos["layers"])
        mask = self._imread(sample_dir / layer["mask_name"], cv2.IMREAD_UNCHANGED)
        # the B channel holds instances
        instances_mask = mask[:, :, 0].astype(np.int32)
        object_ids, _ = get_labels_with_sizes(instances_mask)

        keep_stuff = self.stuff_prob > 0 and random.random() <= self.stuff_prob
        if not keep_stuff:
            # remove stuff objects
            # NOTE(review): membership is tested on the position of the id in
            # ``object_ids``, not on the id itself -- confirm that
            # ``stuff_instances`` stores positions.
            stuff_instances = layer["stuff_instances"]
            for i, object_id in enumerate(object_ids):
                if i in stuff_instances:
                    instances_mask[instances_mask == object_id] = 0
            object_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index)