Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee. | |
# This program is licensed under the Apache License 2.0. | |
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
from typing import List | |
import cv2 | |
import numpy as np | |
from doctr.utils.repr import NestedObject | |
__all__ = ["DetectionPostProcessor"] | |
class DetectionPostProcessor(NestedObject): | |
"""Abstract class to postprocess the raw output of the model | |
Args: | |
---- | |
box_thresh (float): minimal objectness score to consider a box | |
bin_thresh (float): threshold to apply to segmentation raw heatmap | |
assume straight_pages (bool): if True, fit straight boxes only | |
""" | |
def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None: | |
self.box_thresh = box_thresh | |
self.bin_thresh = bin_thresh | |
self.assume_straight_pages = assume_straight_pages | |
self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8) | |
def extra_repr(self) -> str: | |
return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}" | |
def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float: | |
"""Compute the confidence score for a polygon : mean of the p values on the polygon | |
Args: | |
---- | |
pred (np.ndarray): p map returned by the model | |
points: coordinates of the polygon | |
assume_straight_pages: if True, fit straight boxes only | |
Returns: | |
------- | |
polygon objectness | |
""" | |
h, w = pred.shape[:2] | |
if assume_straight_pages: | |
xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1) | |
xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1) | |
ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1) | |
ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1) | |
return pred[ymin : ymax + 1, xmin : xmax + 1].mean() | |
else: | |
mask: np.ndarray = np.zeros((h, w), np.int32) | |
cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] | |
product = pred * mask | |
return np.sum(product) / np.count_nonzero(product) | |
def bitmap_to_boxes( | |
self, | |
pred: np.ndarray, | |
bitmap: np.ndarray, | |
) -> np.ndarray: | |
raise NotImplementedError | |
def __call__( | |
self, | |
proba_map, | |
) -> List[List[np.ndarray]]: | |
"""Performs postprocessing for a list of model outputs | |
Args: | |
---- | |
proba_map: probability map of shape (N, H, W, C) | |
Returns: | |
------- | |
list of N class predictions (for each input sample), where each class predictions is a list of C tensors | |
of shape (*, 5) or (*, 6) | |
""" | |
if proba_map.ndim != 4: | |
raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.") | |
# Erosion + dilation on the binary map | |
bin_map = [ | |
[ | |
cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) | |
for idx in range(proba_map.shape[-1]) | |
] | |
for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) | |
] | |
return [ | |
[self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])] | |
for pmaps, bmaps in zip(proba_map, bin_map) | |
] | |