File size: 3,360 Bytes
2cdd41c
1615d09
 
2cdd41c
 
 
1615d09
2cdd41c
 
 
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
 
 
 
1615d09
 
 
2cdd41c
 
 
 
 
 
1615d09
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
 
 
2cdd41c
 
 
1615d09
2cdd41c
 
 
 
 
 
 
 
 
 
1615d09
 
 
 
2cdd41c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
 
1615d09
 
2cdd41c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import pickle
import random

import numpy as np
import torch
from torchvision import transforms

from .points_sampler import MultiPointSampler
from .sample import DSample


class ISDataset(torch.utils.data.dataset.Dataset):
    """Base dataset yielding an image, sampled points and a target mask per item.

    Subclasses must implement :meth:`get_sample` and assign
    ``self.dataset_samples``. Only ``len(self.dataset_samples)`` is used here.

    Args:
        augmentator: Passed to ``sample.augment``; ``None`` disables augmentation.
        points_sampler: Object exposing ``sample_object(sample)``,
            ``sample_points()`` and ``selected_mask``. When ``None`` (the
            default), a fresh ``MultiPointSampler(max_num_points=12)`` is
            created for this dataset instance.
        min_object_area: Threshold forwarded to ``sample.remove_small_objects``.
        keep_background_prob: Probability of accepting an augmented sample that
            contains no objects (``len(sample) == 0``). A negative value accepts
            such samples unconditionally.
        with_image_info: If true, ``sample.sample_id`` is added to each item
            under the ``"image_info"`` key.
        samples_scores_path: Optional path to a pickle file of per-sample
            scores. When given, ``__getitem__`` ignores its ``index`` argument
            and draws an index at random according to those scores.
        samples_scores_gamma: Exponent applied to ``1 - score`` when turning
            scores into sampling weights.
        epoch_len: If positive, ``__len__`` returns this value and (when no
            scores are loaded) every item is drawn uniformly at random.
    """

    def __init__(
        self,
        augmentator=None,
        points_sampler=None,
        min_object_area=0,
        keep_background_prob=0.0,
        with_image_info=False,
        samples_scores_path=None,
        samples_scores_gamma=1.0,
        epoch_len=-1,
    ):
        super().__init__()
        if points_sampler is None:
            # Built per instance. A default-argument instance would be created
            # once at definition time and shared, together with its per-item
            # state, by every dataset that relies on the default.
            points_sampler = MultiPointSampler(max_num_points=12)

        self.epoch_len = epoch_len
        self.augmentator = augmentator
        self.min_object_area = min_object_area
        self.keep_background_prob = keep_background_prob
        self.points_sampler = points_sampler
        self.with_image_info = with_image_info
        self.samples_precomputed_scores = self._load_samples_scores(
            samples_scores_path, samples_scores_gamma
        )
        self.to_tensor = transforms.ToTensor()

        # Populated by subclasses.
        self.dataset_samples = None

    def __getitem__(self, index):
        """Build one training item.

        Returns:
            A dict with ``"images"`` (``self.to_tensor(sample.image)``),
            ``"points"`` (the sampler's points as a float32 array),
            ``"instances"`` (``self.points_sampler.selected_mask``) and,
            if ``with_image_info`` is set, ``"image_info"``
            (``sample.sample_id``).
        """
        if self.samples_precomputed_scores is not None:
            # Score-weighted sampling: the requested index is ignored.
            index = np.random.choice(
                self.samples_precomputed_scores["indices"],
                p=self.samples_precomputed_scores["probs"],
            )
        else:
            if self.epoch_len > 0:
                # The epoch length is decoupled from the dataset size, so the
                # incoming index may be out of range; draw uniformly instead.
                index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        sample = self.augment_sample(sample)
        sample.remove_small_objects(self.min_object_area)

        self.points_sampler.sample_object(sample)
        points = np.array(self.points_sampler.sample_points())
        mask = self.points_sampler.selected_mask

        output = {
            "images": self.to_tensor(sample.image),
            "points": points.astype(np.float32),
            "instances": mask,
        }

        if self.with_image_info:
            output["image_info"] = sample.sample_id

        return output

    def augment_sample(self, sample) -> "DSample":
        """Augment ``sample`` in place until the result is acceptable.

        A result is accepted if it has at least one object, or if a random draw
        allows keeping an empty one (see ``keep_background_prob``). With the
        default ``keep_background_prob=0.0``, empty results are never kept,
        so this loops until an augmentation yields an object.

        Returns:
            The same ``sample`` object, or it unchanged when no augmentator is set.
        """
        if self.augmentator is None:
            return sample

        valid_augmentation = False
        while not valid_augmentation:
            sample.augment(self.augmentator)
            # Evaluated every iteration, before the emptiness check, so the
            # number of ``random.random()`` calls does not depend on the outcome.
            keep_sample = (
                self.keep_background_prob < 0.0
                or random.random() < self.keep_background_prob
            )
            valid_augmentation = len(sample) > 0 or keep_sample

        return sample

    def get_sample(self, index) -> "DSample":
        """Return the sample at ``index``; must be implemented by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return ``epoch_len`` when positive, otherwise the number of samples."""
        if self.epoch_len > 0:
            return self.epoch_len
        else:
            return self.get_samples_number()

    def get_samples_number(self):
        """Return ``len(self.dataset_samples)``."""
        return len(self.dataset_samples)

    @staticmethod
    def _load_samples_scores(samples_scores_path, samples_scores_gamma):
        """Load per-sample sampling weights from a pickle file.

        Each loaded entry ``x`` contributes index ``x[0]`` with weight
        ``(1.0 - x[2]) ** samples_scores_gamma``. The weights are then
        normalized to sum to 1. (Presumably ``x[2]`` is a quality score in
        [0, 1], so poorly-scored samples are drawn more often. TODO confirm
        against the file producer.)

        Returns:
            ``None`` when ``samples_scores_path`` is ``None``. Otherwise a dict
            with ``"indices"`` (list) and ``"probs"`` (normalized numpy array).

        Raises:
            ValueError: If the weights do not sum to a positive value (for
                example, an empty file or every score equal to 1.0). No valid
                sampling distribution can be built in that case.
        """
        if samples_scores_path is None:
            return None

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(samples_scores_path, "rb") as f:
            images_scores = pickle.load(f)

        probs = np.array(
            [(1.0 - x[2]) ** samples_scores_gamma for x in images_scores],
            dtype=np.float64,
        )
        total = probs.sum()
        # `not total > 0` also rejects NaN.
        if not total > 0:
            raise ValueError(
                f"Sample weights loaded from {samples_scores_path!r} sum to "
                f"{total}; cannot build a sampling distribution."
            )
        probs /= total
        samples_scores = {"indices": [x[0] for x in images_scores], "probs": probs}
        print(f"Loaded {len(probs)} weights with gamma={samples_scores_gamma}")
        return samples_scores