File size: 3,389 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Tuple, Union

import numpy as np

from ..utils import merge_multi_strings

# Public helpers of this module: chunking of wide crops and remapping of the per-chunk predictions
__all__ = ["split_crops", "remap_preds"]


def split_crops(
    crops: List[np.ndarray],
    max_ratio: float,
    target_ratio: int,
    dilation: float,
    channels_last: bool = True,
) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
    """Chunk crops horizontally to match a given aspect ratio

    Each crop whose width / height ratio exceeds `max_ratio` is cut into `aspect_ratio // target_ratio`
    (at least one) horizontal chunks with evenly spaced centers. Each chunk is `dilation` times wider than the
    spacing between centers, so neighbouring chunks overlap when `dilation > 1`. Other crops are kept untouched.

    Args:
    ----
        crops: list of numpy array of shape (H, W, C) if channels_last or (C, H, W) otherwise
        max_ratio: the maximum aspect ratio that won't trigger the chunk
        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
        dilation: the width dilation of final chunks (to provide some overlaps)
        channels_last: whether the numpy array has dimensions in channels last order

    Returns:
    -------
        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required.
        The mapping has one entry per input crop: the index of the crop in the new list when it was kept
        as is, or a `(start, end)` slice of the new list holding its chunks when it was split.
    """
    remap_required = False
    crop_map: List[Union[int, Tuple[int, int]]] = []
    new_crops: List[np.ndarray] = []
    for crop in crops:
        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
        aspect_ratio = w / h
        if aspect_ratio <= max_ratio:
            # Narrow enough: forward the crop unchanged
            crop_map.append(len(new_crops))
            new_crops.append(crop)
            continue
        # Determine the number of crops, reference aspect ratio = 4 = 128 / 32.
        # Keep at least one chunk: when max_ratio < target_ratio the floor division can yield 0,
        # which would otherwise cause a division by zero below
        num_subcrops = max(1, int(aspect_ratio // target_ratio))
        # Find the new widths, additional dilation factor to overlap crops
        width = dilation * w / num_subcrops
        centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
        chunks: List[np.ndarray] = []
        for center in centers:
            # Slice stops are exclusive: clamp to w (not w - 1) so the last pixel column is kept
            start = max(0, int(round(center - width / 2)))
            stop = min(w, int(round(center + width / 2)))
            chunks.append(crop[:, start:stop, :] if channels_last else crop[:, :, start:stop])
        # Avoid sending zero-sized crops
        chunks = [chunk for chunk in chunks if chunk.size > 0]
        # Record the slice of new_crops holding the chunks of this crop
        crop_map.append((len(new_crops), len(new_crops) + len(chunks)))
        new_crops.extend(chunks)
        # At least one crop will require merging
        remap_required = True

    return new_crops, crop_map, remap_required


def remap_preds(
    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
) -> List[Tuple[str, float]]:
    """Map predictions made on (possibly split) crops back to one prediction per original crop

    Args:
    ----
        preds: list of (value, confidence) predictions, one per crop of the list produced by `split_crops`
        crop_map: mapping produced by `split_crops`; an int is the index in `preds` of an unsplit crop,
            a (start, end) tuple is the slice of `preds` holding the chunks of a split crop
        dilation: dilation factor forwarded to `merge_multi_strings` when merging chunk values

    Returns:
    -------
        one (value, confidence) pair per entry of `crop_map`. For split crops, the value is the output of
        `merge_multi_strings` on the chunk values and the confidence is the lowest chunk confidence.
    """

    def _remap_entry(entry: Union[int, Tuple[int, int]]) -> Tuple[str, float]:
        if isinstance(entry, int):
            # Crop hasn't been split: forward its prediction untouched
            return preds[entry]
        # Separate the chunk values from their confidences
        chunk_values, chunk_probs = zip(*preds[entry[0] : entry[1]])
        merged = merge_multi_strings(chunk_values, dilation)  # type: ignore[arg-type]
        return merged, min(chunk_probs)

    return [_remap_entry(entry) for entry in crop_map]