curt-park's picture
Init the space
2cdd41c
raw
history blame
2.71 kB
import cv2
import json
import random
import numpy as np
from pathlib import Path
from isegm.data.base import ISDataset
from isegm.data.sample import DSample
class CocoDataset(ISDataset):
    """Interactive-segmentation dataset built on COCO panoptic annotations.

    The following layout is read from ``dataset_path``::

        annotations/panoptic_<split>.json   # annotation index (JSON)
        annotations/panoptic_<split>/       # per-image panoptic label PNGs
        <split>/                            # images, same stem with ``.jpg``

    Args:
        dataset_path: Root directory of the dataset.
        split: Split name used to build the paths above.
        stuff_prob: Per-sample probability of exposing "stuff" segments
            (categories with ``isthing == 0``) as objects in addition to
            "things". With ``0.0`` stuff is never exposed.
        **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
    """

    def __init__(self, dataset_path, split='train', stuff_prob=0.0, **kwargs):
        super().__init__(**kwargs)
        self.split = split
        self.dataset_path = Path(dataset_path)
        self.stuff_prob = stuff_prob

        self.load_samples()

    def load_samples(self):
        """Load the panoptic annotation index and derive category lookups.

        Sets ``labels_path``, ``images_path``, ``dataset_samples`` (the
        ``'annotations'`` list of the JSON file) and the things/stuff
        category id lists and sets.
        """
        annotations_dir = self.dataset_path / 'annotations'
        annotation_path = annotations_dir / f'panoptic_{self.split}.json'
        self.labels_path = annotations_dir / f'panoptic_{self.split}'
        self.images_path = self.dataset_path / self.split

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(annotation_path, 'r', encoding='utf-8') as f:
            annotation = json.load(f)

        self.dataset_samples = annotation['annotations']

        self._categories = annotation['categories']
        self._stuff_labels = [x['id'] for x in self._categories if x['isthing'] == 0]
        self._things_labels = [x['id'] for x in self._categories if x['isthing'] == 1]
        self._things_labels_set = set(self._things_labels)
        self._stuff_labels_set = set(self._stuff_labels)

    @staticmethod
    def _read_image(path, flags):
        """Read ``path`` with OpenCV, raising instead of returning ``None``.

        Raises:
            FileNotFoundError: If OpenCV could not read or decode the file.
        """
        data = cv2.imread(str(path), flags)
        if data is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image file: {path}')
        return data

    def get_sample(self, index) -> DSample:
        """Build the sample for the annotation entry at ``index``.

        Args:
            index: Position in ``self.dataset_samples``.

        Returns:
            A ``DSample`` holding the RGB image, an instance map in which each
            kept segment's pixels carry the segment id (0 elsewhere), and the
            list of selectable segment ids.

        Raises:
            FileNotFoundError: If the image or its label PNG cannot be read.
        """
        dataset_sample = self.dataset_samples[index]

        image_path = self.images_path / self.get_image_name(dataset_sample['file_name'])
        label_path = self.labels_path / dataset_sample['file_name']

        image = self._read_image(image_path, cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # int32 so the packed id below cannot overflow the 8-bit channels.
        label = self._read_image(label_path, cv2.IMREAD_UNCHANGED).astype(np.int32)
        # Collapse the three channels into one segment id per pixel. OpenCV
        # yields BGR order, so this is R + 256 * G + 256**2 * B, which should
        # match the COCO panoptic id encoding.
        label = 256 * 256 * label[:, :, 0] + 256 * label[:, :, 1] + label[:, :, 2]

        instance_map = np.zeros_like(label)
        things_ids = []
        stuff_ids = []

        for segment in dataset_sample['segments_info']:
            class_id = segment['category_id']
            obj_id = segment['id']
            if class_id in self._things_labels_set:
                # Crowd regions are left as background and are not selectable.
                if segment['iscrowd'] == 1:
                    continue
                things_ids.append(obj_id)
            else:
                stuff_ids.append(obj_id)

            instance_map[label == obj_id] = obj_id

        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            instances_ids = things_ids + stuff_ids
        else:
            instances_ids = things_ids
            # Stuff is not exposed for this sample: clear it from the map.
            if stuff_ids:
                instance_map[np.isin(instance_map, stuff_ids)] = 0

        return DSample(image, instance_map, objects_ids=instances_ids)

    @classmethod
    def get_image_name(cls, panoptic_name):
        """Return the image file name for a panoptic label file name.

        Every occurrence of ``'.png'`` in ``panoptic_name`` is replaced by
        ``'.jpg'``.
        """
        return panoptic_name.replace('.png', '.jpg')