adirathor07's picture
added doctr folder
153628e
raw
history blame
3.39 kB
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from typing import List, Tuple, Union
import numpy as np
from ..utils import merge_multi_strings
__all__ = ["split_crops", "remap_preds"]
def split_crops(
    crops: List[np.ndarray],
    max_ratio: float,
    target_ratio: int,
    dilation: float,
    channels_last: bool = True,
) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
    """Chunk crops horizontally to match a given aspect ratio

    Crops whose aspect ratio (width / height) exceeds `max_ratio` are cut into
    horizontally overlapping sub-crops; all other crops are forwarded untouched.

    Args:
    ----
        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
        max_ratio: the maximum aspect ratio that won't trigger the chunk
        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
        dilation: the width dilation of final chunks (to provide some overlaps)
        channels_last: whether the numpy array has dimensions in channels last order

    Returns:
    -------
        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required.
        In the mapping, an unsplit crop is recorded as its integer index in the new crops, while a split
        crop is recorded as the half-open index range `(start, stop)` of its sub-crops.
    """
    _remap_required = False
    crop_map: List[Union[int, Tuple[int, int]]] = []
    new_crops: List[np.ndarray] = []
    for crop in crops:
        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
        aspect_ratio = w / h
        if aspect_ratio > max_ratio:
            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
            # Clamp to 1: when max_ratio < aspect_ratio < target_ratio the floor division
            # yields 0, which would otherwise trigger a division by zero below
            num_subcrops = max(1, int(aspect_ratio // target_ratio))
            # Find the new widths, additional dilation factor to overlap crops
            width = dilation * w / num_subcrops
            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
            # Horizontal bounds of each chunk, clipped to the crop. The stop index of a
            # slice is exclusive, hence the clamp to `w` (not `w - 1`) to keep the last column
            bounds = [
                (max(0, int(round(center - width / 2))), min(w, int(round(center + width / 2))))
                for center in centers
            ]
            # Get the crops
            if channels_last:
                _crops = [crop[:, start:stop, :] for start, stop in bounds]
            else:
                _crops = [crop[:, :, start:stop] for start, stop in bounds]
            # Avoid sending zero-sized crops
            _crops = [_crop for _crop in _crops if all(s > 0 for s in _crop.shape)]
            # Record the slice of crops
            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
            new_crops.extend(_crops)
            # At least one crop will require merging
            _remap_required = True
        else:
            crop_map.append(len(new_crops))
            new_crops.append(crop)
    return new_crops, crop_map, _remap_required
def remap_preds(
    preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
) -> List[Tuple[str, float]]:
    """Map predictions made on (possibly split) crops back to one prediction per original crop

    Args:
    ----
        preds: list of (value, confidence) predictions, one per entry of the new crops
        crop_map: for each original crop, either the integer index of its prediction in `preds`
            (crop was not split) or a `(start, stop)` pair delimiting the half-open range of
            predictions of its sub-crops
        dilation: value forwarded to `merge_multi_strings` when merging the sub-crop strings

    Returns:
    -------
        a list of (value, confidence) tuples with one entry per element of `crop_map`; for split
        crops the value is the merged string and the confidence is the minimum over the sub-crops
    """
    merged_preds = []
    for entry in crop_map:
        if isinstance(entry, int):
            # Unsplit crop: its prediction is forwarded as is
            merged_preds.append(preds[entry])
            continue
        # Split crop: gather the predictions of all its sub-crops
        start, stop = entry[0], entry[1]
        words, confidences = zip(*preds[start:stop])
        # Merge the string values
        merged_word = merge_multi_strings(words, dilation)  # type: ignore[arg-type]
        # Keep the most pessimistic confidence among the sub-crops
        merged_preds.append((merged_word, min(confidences)))
    return merged_preds