Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee. | |
# This program is licensed under the Apache License 2.0. | |
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
from typing import List, Tuple, Union | |
import numpy as np | |
from ..utils import merge_multi_strings | |
__all__ = ["split_crops", "remap_preds"] | |
def split_crops(
    crops: List[np.ndarray],
    max_ratio: float,
    target_ratio: int,
    dilation: float,
    channels_last: bool = True,
) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
    """Chunk crops horizontally to match a given aspect ratio

    Crops whose aspect ratio (width / height) does not exceed `max_ratio` are passed through
    untouched. Wider crops are cut into evenly spaced, horizontally overlapping chunks.

    Args:
    ----
        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
        max_ratio: the maximum aspect ratio that won't trigger the chunk
        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
        dilation: the width dilation of final chunks (to provide some overlaps)
        channels_last: whether the numpy array has dimensions in channels last order

    Returns:
    -------
        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required.
        The mapping holds one entry per input crop: an int index into the new crops when the crop was
        kept as is, or a (start, end) pair delimiting the slice of new crops that came from it.
    """
    _remap_required = False
    crop_map: List[Union[int, Tuple[int, int]]] = []
    new_crops: List[np.ndarray] = []
    for crop in crops:
        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
        aspect_ratio = w / h

        # Narrow enough: forward the crop unchanged
        if aspect_ratio <= max_ratio:
            crop_map.append(len(new_crops))
            new_crops.append(crop)
            continue

        # Number of chunks = how many times the target ratio fits in the crop's ratio.
        # Clamp to 1 so that max_ratio < aspect_ratio < target_ratio does not divide by zero.
        num_subcrops = max(1, int(aspect_ratio // target_ratio))
        # Find the new widths, additional dilation factor to overlap crops
        width = dilation * w / num_subcrops
        centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]

        _crops = []
        for center in centers:
            start = max(0, int(round(center - width / 2)))
            # Slice stops are exclusive: clamp to w (not w - 1) to keep the last pixel column
            stop = min(w, int(round(center + width / 2)))
            chunk = crop[:, start:stop, :] if channels_last else crop[:, :, start:stop]
            # Avoid sending zero-sized crops
            if all(s > 0 for s in chunk.shape):
                _crops.append(chunk)

        # Record the slice of crops
        crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
        new_crops.extend(_crops)
        # At least one crop will require merging
        _remap_required = True

    return new_crops, crop_map, _remap_required
def remap_preds(
    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
) -> List[Tuple[str, float]]:
    """Rebuild one prediction per mapping entry from predictions made on (possibly split) crops

    Args:
    ----
        preds: list of (value, probability) pairs, indexed like the crops the mapping refers to
        crop_map: one entry per output prediction, either an int index into `preds` or a
            (start, end) pair delimiting the slice of `preds` to merge
        dilation: forwarded as second argument to `merge_multi_strings` when merging a slice

    Returns:
    -------
        a list with one (value, probability) pair per entry of `crop_map`; merged entries carry the
        merged string and the lowest probability of the slice
    """
    merged_preds = []
    for entry in crop_map:
        if not isinstance(entry, int):
            # Split crop: gather the predictions of all its chunks
            chunk_preds = preds[entry[0] : entry[1]]
            words, confidences = zip(*chunk_preds)
            # Merge the chunk strings, keep the most pessimistic confidence
            merged_word = merge_multi_strings(words, dilation)  # type: ignore[arg-type]
            merged_preds.append((merged_word, min(confidences)))
            continue
        # Crop was forwarded as is: its prediction is reused directly
        merged_preds.append(preds[entry])
    return merged_preds