# curt-park's picture
# Refactor code
# 1615d09
import json
import pickle
import random
from copy import deepcopy
from pathlib import Path
import cv2
import numpy as np
from isegm.data.base import ISDataset
from isegm.data.sample import DSample
class CocoLvisDataset(ISDataset):
    """Dataset that serves images with packed, multi-layer instance masks.

    Files are read from the following layout::

        <dataset_path>/<split>/<anno_file>               pickled annotations
        <dataset_path>/<split>/images/<image_id>.jpg     source images
        <dataset_path>/<split>/masks/<image_id>.pickle   packed mask layers

    NOTE(security): the annotation file and every mask file are loaded with
    ``pickle``, which can execute arbitrary code while unpickling. Only point
    this class at dataset directories you trust.
    """

    def __init__(
        self,
        dataset_path,
        split="train",
        stuff_prob=0.0,
        allow_list_name=None,
        anno_file="hannotation.pickle",
        **kwargs,
    ):
        """Load the annotation index for one split.

        Args:
            dataset_path: Root directory of the dataset.
            split: Name of the sub-directory of ``dataset_path`` to read.
            stuff_prob: Per-sample probability of exposing the masks whose
                index is at or beyond ``num_instance_masks`` as objects
                instead of erasing them (presumably the "stuff" regions --
                confirm against the dataset preparation script).
            allow_list_name: Optional name of a JSON file inside the split
                directory holding the image ids to keep. When given, every
                other sample is dropped.
            anno_file: Name of the pickled annotation file inside the split
                directory. The unpickled object must provide ``.items()``
                yielding ``(image_id, sample)`` pairs.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
        """
        super().__init__(**kwargs)
        dataset_path = Path(dataset_path)
        self._split_path = dataset_path / split
        self.split = split
        self._images_path = self._split_path / "images"
        self._masks_path = self._split_path / "masks"
        self.stuff_prob = stuff_prob

        # Sorting the (image_id, sample) pairs gives a stable index -> sample
        # mapping regardless of the order stored in the pickle.
        with open(self._split_path / anno_file, "rb") as f:
            self.dataset_samples = sorted(pickle.load(f).items())

        if allow_list_name is not None:
            allow_list_path = self._split_path / allow_list_name
            with open(allow_list_path, "r", encoding="utf-8") as f:
                # A set keeps the membership test below O(1) per sample.
                allow_images_ids = set(json.load(f))

            self.dataset_samples = [
                sample
                for sample in self.dataset_samples
                if sample[0] in allow_images_ids
            ]

    def get_sample(self, index) -> DSample:
        """Build the ``DSample`` stored at position ``index``.

        Args:
            index: Position in ``self.dataset_samples``.

        Returns:
            A ``DSample`` built from the RGB image, the stacked mask layers
            and the per-object info dictionary.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        image_id, sample = self.dataset_samples[index]

        image_path = self._images_path / f"{image_id}.jpg"
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than inside cvtColor.
            raise OSError(f"Failed to read image (missing or undecodable): {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        packed_masks_path = self._masks_path / f"{image_id}.pickle"
        with open(packed_masks_path, "rb") as f:
            encoded_layers, objs_mapping = pickle.load(f)
        layers = [cv2.imdecode(x, cv2.IMREAD_UNCHANGED) for x in encoded_layers]
        # Decoded layers are stacked along a new last axis.
        layers = np.stack(layers, axis=2)

        # Work on a copy so the cached annotations are never mutated.
        instances_info = deepcopy(sample["hierarchy"])
        for inst_id, inst_info in instances_info.items():
            if inst_info is None:
                # No hierarchy entry: treat the instance as a standalone root.
                # Rebinding an existing key is safe while iterating.
                inst_info = {"children": [], "parent": None, "node_level": 0}
                instances_info[inst_id] = inst_info
            inst_info["mapping"] = objs_mapping[inst_id]

        # Mapping entries at or beyond num_instance_masks are the extra masks
        # that are either exposed as objects or erased from the layers.
        extra_ids = range(sample["num_instance_masks"], len(objs_mapping))
        if self.stuff_prob > 0 and random.random() < self.stuff_prob:
            for inst_id in extra_ids:
                instances_info[inst_id] = {
                    "mapping": objs_mapping[inst_id],
                    "parent": None,
                    "children": [],
                }
        else:
            for inst_id in extra_ids:
                layer_indx, mask_id = objs_mapping[inst_id]
                # Basic slicing yields a view, so this zeroes the matching
                # pixels of `layers` in place.
                layer = layers[:, :, layer_indx]
                layer[layer == mask_id] = 0

        return DSample(image, layers, objects=instances_info)