File size: 4,696 Bytes
2cdd41c
 
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
1615d09
2cdd41c
1615d09
2cdd41c
 
 
1615d09
 
2cdd41c
 
 
1615d09
2cdd41c
 
 
 
1615d09
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
1615d09
2cdd41c
 
 
1615d09
 
2cdd41c
1615d09
2cdd41c
 
 
 
 
 
1615d09
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
1615d09
2cdd41c
 
1615d09
2cdd41c
 
 
 
 
1615d09
 
 
 
2cdd41c
 
 
 
 
1615d09
2cdd41c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import pickle as pkl
from pathlib import Path

import cv2
import numpy as np
from scipy.io import loadmat

from isegm.data.base import ISDataset
from isegm.data.sample import DSample
from isegm.utils.misc import get_bbox_from_mask, get_labels_with_sizes


class SBDDataset(ISDataset):
    """Dataset over an SBD-style directory, yielding one sample per image.

    Expected layout under ``dataset_path``:

    * ``<split>.txt`` -- one sample name per line;
    * ``img/<name>.jpg`` -- the image;
    * ``inst/<name>.mat`` -- the instance annotation (``GTinst`` struct).

    Instances whose mask covers too small a fraction of their own bounding
    box are treated as annotation errors and erased from the instance map
    (see ``remove_buggy_masks``).
    """

    def __init__(self, dataset_path, split="train", buggy_mask_thresh=0.08, **kwargs):
        """Read the list of sample names for ``split``.

        Args:
            dataset_path: Root directory of the dataset.
            split: Either ``"train"`` or ``"val"``.
            buggy_mask_thresh: Minimum allowed ``mask_area / bbox_area``
                ratio for an instance; values ``<= 0`` disable filtering.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
        """
        super().__init__(**kwargs)
        assert split in {"train", "val"}, f"Unknown SBD split: {split!r}"

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"
        # Per-index cache of instance ids found to be buggy, so the area
        # ratio check runs only once per sample.
        self._buggy_objects = dict()
        self._buggy_mask_thresh = buggy_mask_thresh

        with open(self.dataset_path / f"{split}.txt", "r") as f:
            stripped_lines = (line.strip() for line in f)
            # Blank lines would otherwise become empty sample names and
            # resolve to non-existent files such as "img/.jpg".
            self.dataset_samples = [name for name in stripped_lines if name]

    def get_sample(self, index):
        """Load the image and filtered instance map for sample ``index``.

        Returns:
            A ``DSample`` built from the RGB image, the int32 instance map
            (buggy instances zeroed out) and the remaining instance ids.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        image_name = self.dataset_samples[index]
        image_path = str(self._images_path / f"{image_name}.jpg")
        inst_info_path = str(self._insts_path / f"{image_name}.mat")

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the path rather than inside cvtColor.
            raise OSError(f"Image is missing or unreadable: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # First field of the GTinst struct -- presumably the per-pixel
        # instance-id map; confirm against the SBD annotation format.
        instances_mask = loadmat(inst_info_path)["GTinst"][0][0][0].astype(np.int32)
        instances_mask = self.remove_buggy_masks(index, instances_mask)
        instances_ids, _ = get_labels_with_sizes(instances_mask)

        return DSample(
            image, instances_mask, objects_ids=instances_ids, sample_id=index
        )

    def remove_buggy_masks(self, index, instances_mask):
        """Zero out instances with a suspiciously low mask/bbox area ratio.

        The set of buggy ids is computed on the first call for a given
        ``index`` and cached for later calls.

        Args:
            index: Sample index, used as the cache key.
            instances_mask: Integer instance map; modified in place.

        Returns:
            The same ``instances_mask`` array, with buggy instances set to 0.
        """
        if not self._buggy_mask_thresh > 0.0:
            return instances_mask

        buggy_image_objects = self._buggy_objects.get(index)
        if buggy_image_objects is None:
            buggy_image_objects = self._find_buggy_objects(instances_mask)
            self._buggy_objects[index] = buggy_image_objects

        for obj_id in buggy_image_objects:
            instances_mask[instances_mask == obj_id] = 0

        return instances_mask

    def _find_buggy_objects(self, instances_mask):
        """Return ids of instances whose area ratio is below the threshold."""
        buggy_ids = []
        instances_ids, _ = get_labels_with_sizes(instances_mask)
        for obj_id in instances_ids:
            obj_mask = instances_mask == obj_id
            mask_area = obj_mask.sum()
            bbox = get_bbox_from_mask(obj_mask)
            # bbox entries 0/1 and 2/3 are used as inclusive min/max pairs
            # (hence the +1) -- presumably rows then columns; verify
            # against get_bbox_from_mask.
            bbox_area = (bbox[1] - bbox[0] + 1) * (bbox[3] - bbox[2] + 1)
            if mask_area / bbox_area < self._buggy_mask_thresh:
                buggy_ids.append(obj_id)
        return buggy_ids


class SBDEvaluationDataset(ISDataset):
    """Dataset over an SBD-style directory, yielding one sample per instance.

    Expected layout under ``dataset_path``:

    * ``<split>.txt`` -- one image name per line;
    * ``img/<name>.jpg`` -- the image;
    * ``inst/<name>.mat`` -- the instance annotation (``GTinst`` struct);
    * ``<split>_images_and_ids_list.pkl`` -- cache of
      ``(image_name, instance_id)`` pairs, created on first use.
    """

    def __init__(self, dataset_path, split="val", **kwargs):
        """Build the list of ``(image_name, instance_id)`` samples.

        Args:
            dataset_path: Root directory of the dataset.
            split: Either ``"train"`` or ``"val"``.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.
        """
        super().__init__(**kwargs)
        assert split in {"train", "val"}, f"Unknown SBD split: {split!r}"

        self.dataset_path = Path(dataset_path)
        self.dataset_split = split
        self._images_path = self.dataset_path / "img"
        self._insts_path = self.dataset_path / "inst"

        with open(self.dataset_path / f"{split}.txt", "r") as f:
            stripped_lines = (line.strip() for line in f)
            # Blank lines would otherwise become empty image names and
            # resolve to non-existent files such as "inst/.mat".
            self.dataset_samples = [name for name in stripped_lines if name]

        # Replace the per-image list with a per-instance list.
        self.dataset_samples = self.get_sbd_images_and_ids_list()

    def get_sample(self, index) -> DSample:
        """Load the image and a binary mask of a single instance.

        Returns:
            A ``DSample`` with the RGB image and an int32 mask in which the
            selected instance is 1 and everything else is 0.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        image_name, instance_id = self.dataset_samples[index]
        image_path = str(self._images_path / f"{image_name}.jpg")

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the path rather than inside cvtColor.
            raise OSError(f"Image is missing or unreadable: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        instances_mask = self._load_instances_mask(image_name)
        # Keep only the requested instance, then binarize it.
        instances_mask[instances_mask != instance_id] = 0
        instances_mask[instances_mask > 0] = 1

        return DSample(image, instances_mask, objects_ids=[1], sample_id=index)

    def get_sbd_images_and_ids_list(self):
        """Return all ``(image_name, instance_id)`` pairs for the split.

        The list is loaded from the pickle cache in ``dataset_path`` when
        that file exists; otherwise it is computed by scanning every
        annotation file and then written to the cache.
        """
        pkl_path = self.dataset_path / f"{self.dataset_split}_images_and_ids_list.pkl"

        if pkl_path.exists():
            # NOTE(security): unpickling executes arbitrary code; this
            # cache must only ever come from a trusted dataset directory.
            with pkl_path.open("rb") as fp:
                return pkl.load(fp)

        images_and_ids_list = []
        for sample in self.dataset_samples:
            instances_mask = self._load_instances_mask(sample)
            instances_ids, _ = get_labels_with_sizes(instances_mask)
            images_and_ids_list.extend(
                (sample, instances_id) for instances_id in instances_ids
            )

        with pkl_path.open("wb") as fp:
            pkl.dump(images_and_ids_list, fp)

        return images_and_ids_list

    def _load_instances_mask(self, image_name):
        """Read the int32 instance map from ``inst/<image_name>.mat``."""
        inst_info_path = str(self._insts_path / f"{image_name}.mat")
        # First field of the GTinst struct -- presumably the per-pixel
        # instance-id map; confirm against the SBD annotation format.
        return loadmat(inst_info_path)["GTinst"][0][0][0].astype(np.int32)