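"""Crop-based test-time transform for interactive segmentation inference.

Crops tiles an oversized input into overlapping fixed-size windows and
stitches the per-crop probability maps back into a full-resolution map,
averaging the overlapping regions.

Minimal usage sketch (illustrative only; ``predictor`` is a placeholder for
whatever callable produces per-crop probability maps, and the ``Click``
constructor arguments shown here are assumed, not taken from this file)::

    image_nd = torch.rand(1, 3, 600, 900)   # NCHW image larger than crop_size
    clicks_lists = [[Click(is_positive=True, coords=(300, 450))]]

    crops = Crops(crop_size=(320, 480), min_overlap=0.2)
    image_crops, crop_clicks = crops.transform(image_nd, clicks_lists)
    per_crop_probs = predictor(image_crops, crop_clicks)  # (num_crops, 1, 320, 480)
    prob_map = crops.inv_transform(per_crop_probs)        # (1, 1, 600, 900)
"""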
import math
from typing import List

import numpy as np
import torch

from isegm.inference.clicker import Click

from .base import BaseTransform


class Crops(BaseTransform):
    """Tile an oversized image into overlapping fixed-size crops.

    transform() produces a batch of crops plus per-crop click lists, and
    inv_transform() merges the per-crop probability maps back into one map,
    averaging the regions where crops overlap.
    """

    def __init__(self, crop_size=(320, 480), min_overlap=0.2):
        super().__init__()
        self.crop_height, self.crop_width = crop_size
        self.min_overlap = min_overlap

        # Crop grid and per-pixel coverage counts, computed in transform().
        self.x_offsets = None
        self.y_offsets = None
        self._counts = None

    def transform(self, image_nd, clicks_lists: List[List[Click]]):
        # Expects a single NCHW image and a single list of clicks.
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        image_height, image_width = image_nd.shape[2:4]
        self._counts = None

        # The image is smaller than the crop window: nothing to do.
        if image_height < self.crop_height or image_width < self.crop_width:
            return image_nd, clicks_lists

        self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap)
        self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap)
        self._counts = np.zeros((image_height, image_width))

        # Extract every crop and record how many crops cover each pixel,
        # so overlapping predictions can be averaged in inv_transform().
        image_crops = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                self._counts[dy : dy + self.crop_height, dx : dx + self.crop_width] += 1
                image_crop = image_nd[
                    :, :, dy : dy + self.crop_height, dx : dx + self.crop_width
                ]
                image_crops.append(image_crop)
        image_crops = torch.cat(image_crops, dim=0)
        self._counts = torch.tensor(
            self._counts, device=image_nd.device, dtype=torch.float32
        )

        # Shift click coordinates into each crop's local coordinate frame.
        clicks_list = clicks_lists[0]
        clicks_lists = []
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                crop_clicks = [
                    x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx))
                    for x in clicks_list
                ]
                clicks_lists.append(crop_clicks)

        return image_crops, clicks_lists

    def inv_transform(self, prob_map):
        # No crops were made (image smaller than crop size): return as-is.
        if self._counts is None:
            return prob_map

        new_prob_map = torch.zeros(
            (1, 1, *self._counts.shape), dtype=prob_map.dtype, device=prob_map.device
        )

        # Paste each crop's prediction back at its original offset, then
        # divide by the coverage counts to average the overlapping regions.
        crop_indx = 0
        for dy in self.y_offsets:
            for dx in self.x_offsets:
                new_prob_map[
                    0, 0, dy : dy + self.crop_height, dx : dx + self.crop_width
                ] += prob_map[crop_indx, 0]
                crop_indx += 1
        new_prob_map = torch.div(new_prob_map, self._counts)

        return new_prob_map

    def get_state(self):
        return self.x_offsets, self.y_offsets, self._counts

    def set_state(self, state):
        self.x_offsets, self.y_offsets, self._counts = state

    def reset(self):
        self.x_offsets = None
        self.y_offsets = None
        self._counts = None


def get_offsets(length, crop_size, min_overlap_ratio=0.2):
    """Compute crop start positions along one axis.

    Picks the smallest number of crops N whose mutual overlap is at least
    min_overlap_ratio, then spaces them so the last crop ends exactly at
    `length`. Example: length=600, crop_size=480, min_overlap_ratio=0.2
    gives N=2, overlap_width=360 and offsets [0, 120].
    """
    if length == crop_size:
        return [0]

    # Smallest N satisfying N*crop_size - (N-1)*min_overlap_ratio*crop_size >= length.
    N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio)
    N = math.ceil(N)

    # Actual overlap needed for N crops to span the full length.
    overlap_ratio = (N - length / crop_size) / (N - 1)
    overlap_width = int(crop_size * overlap_ratio)

    offsets = [0]
    for i in range(1, N):
        new_offset = offsets[-1] + crop_size - overlap_width
        if new_offset + crop_size > length:
            new_offset = length - crop_size

        offsets.append(new_offset)

    return offsets