File size: 2,482 Bytes
2cdd41c
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
1615d09
2cdd41c
 
 
1615d09
 
2cdd41c
 
1615d09
 
 
2cdd41c
1615d09
2cdd41c
 
 
1615d09
 
2cdd41c
 
 
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import pickle as pkl
import random
from pathlib import Path

import cv2
import numpy as np

from isegm.data.base import ISDataset
from isegm.data.sample import DSample


class OpenImagesDataset(ISDataset):
    """Interactive-segmentation dataset over an OpenImages-style directory tree.

    Layout read by this class::

        <dataset_path>/<split>/images/<image_id>.jpg
        <dataset_path>/<split>/masks/<mask file name>
        <dataset_path>/<split>/<split>-annotations-object-segmentation_clean.pkl

    The annotation pickle must hold a mapping with the keys
    ``"image_id_to_masks"`` (image id -> non-empty sequence of mask file names,
    relative to ``masks/``) and ``"dataset_samples"`` (indexable sequence of
    image ids; position ``i`` is sample ``i``).
    """

    def __init__(self, dataset_path, split="train", **kwargs):
        """Locate the split directories and load the cleaned annotation index.

        Args:
            dataset_path: Root directory of the dataset.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            **kwargs: Forwarded unchanged to ``ISDataset.__init__``.

        Raises:
            AssertionError: If ``split`` is not one of the supported names.
            RuntimeError: If the cleaned annotation pickle does not exist.
        """
        super().__init__(**kwargs)
        assert split in {"train", "val", "test"}

        self.dataset_path = Path(dataset_path)
        self._split_path = self.dataset_path / split
        self._images_path = self._split_path / "images"
        self._masks_path = self._split_path / "masks"
        self.dataset_split = split

        clean_anno_path = (
            self._split_path / f"{split}-annotations-object-segmentation_clean.pkl"
        )
        if not clean_anno_path.exists():
            raise RuntimeError(f"Can't find annotations at {clean_anno_path}")
        with clean_anno_path.open("rb") as f:
            # NOTE(review): pickle.load can execute arbitrary code embedded in
            # the file; only load annotation files from a trusted source.
            annotations = pkl.load(f)
        self.image_id_to_masks = annotations["image_id_to_masks"]
        self.dataset_samples = annotations["dataset_samples"]

    def get_sample(self, index) -> DSample:
        """Load sample ``index``: its RGB image plus one randomly chosen mask.

        Each call draws the mask with ``random.choice``, so repeated calls for
        the same ``index`` may return different objects of the same image.
        If the image and mask sizes disagree, both are resized to the smaller
        width and the smaller height: bilinear for the image, nearest-neighbour
        for the mask.

        Args:
            index: Position in ``self.dataset_samples``.

        Returns:
            A ``DSample`` holding:

            - the RGB image;
            - an ``int32`` mask in which pixels whose grayscale value is > 0
              are 1 and the rest are 0;
            - ``objects_ids=[1]`` if the mask has any foreground, else ``[]``;
            - ``sample_id=index``.

        Raises:
            RuntimeError: If the image or the chosen mask file cannot be read.
        """
        image_id = self.dataset_samples[index]

        image = self._imread_or_raise(self._images_path / f"{image_id}.jpg")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # select random mask for an image
        mask_name = random.choice(self.image_id_to_masks[image_id])
        mask = self._imread_or_raise(self._masks_path / mask_name)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        height = min(image.shape[0], mask.shape[0])
        width = min(image.shape[1], mask.shape[1])
        image = self._resize_if_needed(image, width, height, cv2.INTER_LINEAR)
        # Resize the uint8 mask before casting: nearest-neighbour only copies
        # existing pixel values, so thresholding afterwards gives the same
        # result without depending on cv2.resize accepting int32 input.
        mask = self._resize_if_needed(mask, width, height, cv2.INTER_NEAREST)

        instances_mask = (mask > 0).astype(np.int32)
        object_ids = [1] if instances_mask.any() else []

        return DSample(image, instances_mask, objects_ids=object_ids, sample_id=index)

    @staticmethod
    def _imread_or_raise(path):
        """Read ``path`` with ``cv2.imread`` and raise if nothing could be decoded."""
        image = cv2.imread(str(path))
        # cv2.imread reports a missing or corrupt file by returning None instead
        # of raising; fail here with the offending path rather than later with an
        # opaque OpenCV assertion.
        if image is None:
            raise RuntimeError(f"Can't read image at {path}")
        return image

    @staticmethod
    def _resize_if_needed(array, width, height, interpolation):
        """Return ``array`` resized to ``height`` x ``width``, or unchanged if it already matches."""
        if array.shape[0] == height and array.shape[1] == width:
            return array
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(array, (width, height), interpolation=interpolation)