# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import List, Tuple, Union

import numpy as np

from ..utils import merge_multi_strings

__all__ = ["split_crops", "remap_preds"]


def split_crops(
    crops: List[np.ndarray],
    max_ratio: float,
    target_ratio: int,
    dilation: float,
    channels_last: bool = True,
) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
    """Chunk crops horizontally to match a given aspect ratio

    Args:
    ----
        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
        max_ratio: the maximum aspect ratio that won't trigger the chunk
        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
        dilation: the width dilation of final chunks (to provide some overlaps)
        channels_last: whether the numpy array has dimensions in channels last order

    Returns:
    -------
        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required.
        Each mapping entry is either the index of an unsplit crop in the new crop list, or a
        ``(start, end)`` slice of the new crop list holding the chunks of a split crop.
    """
    _remap_required = False
    crop_map: List[Union[int, Tuple[int, int]]] = []
    new_crops: List[np.ndarray] = []
    for crop in crops:
        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
        aspect_ratio = w / h
        # Determine the number of crops, reference aspect ratio = 4 = 128 / 32.
        # This is 0 when the crop is narrower than `target_ratio` (possible if max_ratio < target_ratio);
        # such a crop cannot be chunked (it would divide by zero below), so it is kept whole.
        num_subcrops = int(aspect_ratio // target_ratio) if aspect_ratio > max_ratio else 0
        if num_subcrops < 1:
            crop_map.append(len(new_crops))
            new_crops.append(crop)
            continue

        # Find the new widths, additional dilation factor to overlap crops
        width = dilation * w / num_subcrops
        centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
        # Get the crops
        _crops: List[np.ndarray] = []
        for center in centers:
            start = max(0, int(round(center - width / 2)))
            # Slice ends are exclusive: clamp to `w` (not `w - 1`) so the last column is kept
            end = min(w, int(round(center + width / 2)))
            _crops.append(crop[:, start:end, :] if channels_last else crop[:, :, start:end])
        # Avoid sending zero-sized crops (size is 0 iff some dimension is 0)
        _crops = [subcrop for subcrop in _crops if subcrop.size > 0]
        # Record the slice of crops
        crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
        new_crops.extend(_crops)
        # At least one crop will require merging
        _remap_required = True

    return new_crops, crop_map, _remap_required


def remap_preds(
    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
) -> List[Tuple[str, float]]:
    """Merge the predictions of chunked crops back into one prediction per original crop

    Args:
    ----
        preds: list of (string, confidence) predictions, one per crop of the new crop list
        crop_map: the mapping returned by `split_crops`
        dilation: the width dilation used when chunking, forwarded to `merge_multi_strings`

    Returns:
    -------
        a list with one (string, confidence) tuple per entry of `crop_map`. For chunked crops, the chunk
        strings are merged with `merge_multi_strings` and the lowest chunk confidence is kept.
    """
    remapped_out: List[Tuple[str, float]] = []
    for _idx in crop_map:
        # Crop hasn't been split
        if isinstance(_idx, int):
            remapped_out.append(preds[_idx])
        else:
            start, end = _idx
            # unzip
            vals, probs = zip(*preds[start:end])
            # Merge the string values
            remapped_out.append((merge_multi_strings(list(vals), dilation), min(probs)))
    return remapped_out