Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee. | |
# This program is licensed under the Apache License 2.0. | |
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
import math | |
from typing import Any, List, Tuple, Union | |
import numpy as np | |
import tensorflow as tf | |
from doctr.transforms import Normalize, Resize | |
from doctr.utils.multithreading import multithread_exec | |
from doctr.utils.repr import NestedObject | |
__all__ = ["PreProcessor"] | |
class PreProcessor(NestedObject):
    """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.

    Args:
    ----
        output_size: expected size of each page in format (H, W)
        batch_size: the size of page batches
        mean: mean value of the training distribution by channel
        std: standard deviation of the training distribution by channel
        **kwargs: additional keyword arguments forwarded to ``Resize``
    """

    _children_names: List[str] = ["resize", "normalize"]

    def __init__(
        self,
        output_size: Tuple[int, int],
        batch_size: int,
        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
        std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
        **kwargs: Any,
    ) -> None:
        self.batch_size = batch_size
        self.resize = Resize(output_size, **kwargs)
        # Applied batch-wise in __call__. uint8 inputs have already been rescaled
        # to [0, 1] by the dtype cast in `_to_float_tensor` at that point.
        self.normalize = Normalize(mean, std)

    def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
        """Gather samples into batches for inference purposes

        Args:
        ----
            samples: list of samples (tf.Tensor)

        Returns:
        -------
            list of batched samples (the last batch may be smaller than ``batch_size``)
        """
        # Slicing past the end of the list is safe, so the last (partial) batch needs no special case
        return [
            tf.stack(samples[start : start + self.batch_size], axis=0)
            for start in range(0, len(samples), self.batch_size)
        ]

    @staticmethod
    def _to_float_tensor(x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
        """Validate the dtype of an input, convert it to a tf.Tensor and rescale uint8 data.

        Args:
        ----
            x: np.ndarray (uint8 or float32) or tf.Tensor (uint8, float16 or float32)

        Returns:
        -------
            a tf.Tensor. uint8 inputs are cast to float32 in [0, 1]; float inputs keep their dtype

        Raises:
        ------
            TypeError: if the dtype of x is not supported
        """
        if isinstance(x, np.ndarray):
            if x.dtype not in (np.uint8, np.float32):
                raise TypeError("unsupported data type for numpy.ndarray")
            x = tf.convert_to_tensor(x)
        elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
            raise TypeError("unsupported data type for tf.Tensor")
        # Data type & 255 division: convert_image_dtype rescales integer images to [0, 1]
        if x.dtype == tf.uint8:
            x = tf.image.convert_image_dtype(x, dtype=tf.float32)
        return x

    def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
        """Validate, cast and resize a single page.

        Args:
        ----
            x: a 3D page, as a np.ndarray or tf.Tensor

        Returns:
        -------
            the page as a tf.Tensor after the ``Resize`` transform

        Raises:
        ------
            AssertionError: if x is not 3-dimensional
            TypeError: if the dtype of x is not supported
        """
        if x.ndim != 3:
            raise AssertionError("expected list of 3D Tensors")
        x = self._to_float_tensor(x)
        # Resizing
        return self.resize(x)

    def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]:
        """Prepare document data for model forwarding

        Args:
        ----
            x: list of images (np.array) or tensors (already resized and batched)

        Returns:
        -------
            list of page batches

        Raises:
        ------
            AssertionError: if a single array/tensor input is not 4D, or a list element is not 3D
            TypeError: if the input type or dtype is not supported
        """
        # Input type check
        if isinstance(x, (np.ndarray, tf.Tensor)):
            if x.ndim != 4:
                raise AssertionError("expected 4D Tensor")
            x = self._to_float_tensor(x)
            # Resizing: the batch is resized as a whole with the Resize transform's method/antialias settings
            if (x.shape[1], x.shape[2]) != self.resize.output_size:
                x = tf.image.resize(
                    x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
                )
            batches = [x]

        elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
            # Sample transform (to tensor, resize)
            samples = list(multithread_exec(self.sample_transforms, x))
            # Batching
            batches = self.batch_inputs(samples)
        else:
            raise TypeError(f"invalid input type: {type(x)}")

        # Batch transforms (normalize)
        batches = list(multithread_exec(self.normalize, batches))

        return batches