# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List, Sequence, Tuple, Union

import numpy as np
import torch
from torch import nn

from doctr.models.preprocessor import PreProcessor
from doctr.models.utils import set_device_and_dtype

from ._utils import remap_preds, split_crops

__all__ = ["RecognitionPredictor"]

class RecognitionPredictor(nn.Module):
    """Implements an object able to identify character sequences in images

    Args:
    ----
        pre_processor: transform inputs for easier batched model inference
        model: core recognition architecture
        split_wide_crops: whether to use crop splitting for high aspect ratio crops
    """

    def __init__(
        self,
        pre_processor: PreProcessor,
        model: nn.Module,
        split_wide_crops: bool = True,
    ) -> None:
        super().__init__()
        self.pre_processor = pre_processor
        self.model = model.eval()
        self.split_wide_crops = split_wide_crops
        self.critical_ar = 8  # Critical aspect ratio
        self.dil_factor = 1.4  # Dilation factor to overlap the crops
        self.target_ar = 6  # Target aspect ratio

    @torch.inference_mode()
    def forward(
        self,
        crops: Sequence[Union[np.ndarray, torch.Tensor]],
        **kwargs: Any,
    ) -> List[Tuple[str, float]]:
        if len(crops) == 0:
            return []
        # Dimension check
        if any(crop.ndim != 3 for crop in crops):
            raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")

        # Split crops that are too wide
        remapped = False
        if self.split_wide_crops:
            new_crops, crop_map, remapped = split_crops(
                crops,  # type: ignore[arg-type]
                self.critical_ar,
                self.target_ar,
                self.dil_factor,
                isinstance(crops[0], np.ndarray),
            )
            if remapped:
                crops = new_crops

        # Resize & batch them
        processed_batches = self.pre_processor(crops)

        # Forward it
        _params = next(self.model.parameters())
        self.model, processed_batches = set_device_and_dtype(
            self.model, processed_batches, _params.device, _params.dtype
        )
        raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches]

        # Process outputs
        out = [charseq for batch in raw for charseq in batch]

        # Remap crops
        if self.split_wide_crops and remapped:
            out = remap_preds(out, crop_map, self.dil_factor)

        return out
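

if __name__ == "__main__":
    # Minimal usage sketch, not part of the upstream module: it assumes a
    # pretrained CRNN recognition model exposed as doctr.models.crnn_vgg16_bn
    # and its 32x128 input size; the PreProcessor parameters below are
    # illustrative defaults, not values mandated by this class.
    from doctr.models import crnn_vgg16_bn

    reco_model = crnn_vgg16_bn(pretrained=True)
    predictor = RecognitionPredictor(
        PreProcessor(output_size=(32, 128), batch_size=32),
        reco_model,
    )

    # Dummy all-white HxWxC uint8 crop; real inputs would be word crops
    # produced by a text detector.
    dummy_crop = np.full((32, 100, 3), 255, dtype=np.uint8)
    print(predictor([dummy_crop]))  # e.g. [("word", 0.99)]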