# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import math
from typing import Any, List, Tuple, Union

import numpy as np
import torch
from torch import nn
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

from doctr.transforms import Resize
from doctr.utils.multithreading import multithread_exec

__all__ = ["PreProcessor"]


class PreProcessor(nn.Module):
    """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.

    Args:
    ----
        output_size: expected size of each page in format (H, W)
        batch_size: the size of page batches
        mean: mean value of the training distribution by channel
        std: standard deviation of the training distribution by channel
    """

    def __init__(
        self,
        output_size: Tuple[int, int],
        batch_size: int,
        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.resize: T.Resize = Resize(output_size, **kwargs)
        # Normalization expects float tensors in [0, 1]: the division by 255 is performed
        # on uint8 inputs in sample_transforms / __call__ before this transform is applied
        self.normalize = T.Normalize(mean, std)
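
    # Note: with the default mean (0.5 per channel) and std (1.0 per channel),
    # normalized pixel values end up in the [-0.5, 0.5] range.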

    def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]:
        """Gather samples into batches for inference purposes

        Args:
        ----
            samples: list of samples of shape (C, H, W)

        Returns:
        -------
            list of batched samples (*, C, H, W)
        """
        num_batches = int(math.ceil(len(samples) / self.batch_size))
        batches = [
            torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0)
            for idx in range(num_batches)
        ]
        return batches
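
    # Illustration (hypothetical shapes): with batch_size=2, three (C, H, W)
    # samples are grouped into two batches of shapes (2, C, H, W) and (1, C, H, W).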

    def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        if x.ndim != 3:
            raise AssertionError("expected list of 3D Tensors")
        if isinstance(x, np.ndarray):
            if x.dtype not in (np.uint8, np.float32):
                raise TypeError("unsupported data type for numpy.ndarray")
            # Convert from HWC (numpy) to CHW (torch) layout
            x = torch.from_numpy(x.copy()).permute(2, 0, 1)
        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
            raise TypeError("unsupported data type for torch.Tensor")
        # Resizing
        x = self.resize(x)
        # Data type: rescale uint8 values to floats in [0, 1]
        if x.dtype == torch.uint8:
            x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
        else:
            x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
        return x
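
    # In practice, a (H, W, 3) uint8 numpy page comes out of this method as a
    # (3, *output_size) float32 tensor with values in [0, 1] (assuming the
    # default resize behavior).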

    def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
        """Prepare document data for model forwarding

        Args:
        ----
            x: list of images (np.array) or tensors (already resized and batched)

        Returns:
        -------
            list of page batches
        """
        # Input type check
        if isinstance(x, (np.ndarray, torch.Tensor)):
            if x.ndim != 4:
                raise AssertionError("expected 4D Tensor")
            if isinstance(x, np.ndarray):
                if x.dtype not in (np.uint8, np.float32):
                    raise TypeError("unsupported data type for numpy.ndarray")
                # Convert from NHWC (numpy) to NCHW (torch) layout
                x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
                raise TypeError("unsupported data type for torch.Tensor")
            # Resizing
            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
                x = F.resize(
                    x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                )
            # Data type: rescale uint8 values to floats in [0, 1]
            if x.dtype == torch.uint8:  # type: ignore[union-attr]
                x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
            else:
                x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
            batches = [x]

        elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x):
            # Sample transform (to tensor, resize)
            samples = list(multithread_exec(self.sample_transforms, x))
            # Batching
            batches = self.batch_inputs(samples)
        else:
            raise TypeError(f"invalid input type: {type(x)}")

        # Batch transforms (normalize)
        batches = list(multithread_exec(self.normalize, batches))

        return batches
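

if __name__ == "__main__":
    # Minimal usage sketch (hypothetical shapes and values, not part of the library API):
    # preprocess three random uint8 pages into normalized, batched tensors.
    processor = PreProcessor(output_size=(1024, 1024), batch_size=2)
    pages = [np.random.randint(0, 256, (896, 672, 3), dtype=np.uint8) for _ in range(3)]
    batches = processor(pages)
    # Assuming the default resize (no aspect-ratio preservation), this prints
    # [(2, 3, 1024, 1024), (1, 3, 1024, 1024)]
    print([tuple(b.shape) for b in batches])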