# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import math
from typing import Any, List, Tuple, Union
import numpy as np
import tensorflow as tf
from doctr.transforms import Normalize, Resize
from doctr.utils.multithreading import multithread_exec
from doctr.utils.repr import NestedObject
__all__ = ["PreProcessor"]
class PreProcessor(NestedObject):
"""Implements an abstract preprocessor object which performs casting, resizing, batching and normalization.
Args:
----
output_size: expected size of each page in format (H, W)
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
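    Example (illustrative sizes; any page size and batch size can be used):
        >>> import numpy as np
        >>> processor = PreProcessor(output_size=(1024, 1024), batch_size=2)
        >>> batches = processor([np.zeros((512, 512, 3), dtype=np.uint8)])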
"""
_children_names: List[str] = ["resize", "normalize"]
def __init__(
self,
output_size: Tuple[int, int],
batch_size: int,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (1.0, 1.0, 1.0),
**kwargs: Any,
) -> None:
self.batch_size = batch_size
self.resize = Resize(output_size, **kwargs)
        # Normalize expects float tensors in [0, 1]; the division by 255 is handled when casting uint8 inputs to float32
self.normalize = Normalize(mean, std)
def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
"""Gather samples into batches for inference purposes
Args:
----
samples: list of samples (tf.Tensor)
Returns:
-------
list of batched samples
"""
        num_batches = math.ceil(len(samples) / self.batch_size)
        batches = [
            tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0)
            for idx in range(num_batches)
        ]
return batches
def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
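        """Cast a single image to a float32 tensor scaled to [0, 1] and resize it to the target page size
        Args:
        ----
            x: image as a 3D numpy array or tensor of shape (H, W, C)
        Returns:
        -------
            the resized image tensor
        """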
if x.ndim != 3:
raise AssertionError("expected list of 3D Tensors")
if isinstance(x, np.ndarray):
if x.dtype not in (np.uint8, np.float32):
raise TypeError("unsupported data type for numpy.ndarray")
x = tf.convert_to_tensor(x)
elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
raise TypeError("unsupported data type for torch.Tensor")
# Data type & 255 division
if x.dtype == tf.uint8:
x = tf.image.convert_image_dtype(x, dtype=tf.float32)
# Resizing
x = self.resize(x)
return x
def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]:
"""Prepare document data for model forwarding
Args:
----
            x: list of images (np.ndarray or tf.Tensor), or an already-batched 4D array/tensor of pages
Returns:
-------
list of page batches
"""
# Input type check
if isinstance(x, (np.ndarray, tf.Tensor)):
if x.ndim != 4:
raise AssertionError("expected 4D Tensor")
if isinstance(x, np.ndarray):
if x.dtype not in (np.uint8, np.float32):
raise TypeError("unsupported data type for numpy.ndarray")
x = tf.convert_to_tensor(x)
elif x.dtype not in (tf.uint8, tf.float16, tf.float32):
raise TypeError("unsupported data type for torch.Tensor")
# Data type & 255 division
if x.dtype == tf.uint8:
x = tf.image.convert_image_dtype(x, dtype=tf.float32)
# Resizing
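            # The already-batched pages are resized in one call, reusing the interpolation method and antialias
            # settings of the sample-level Resize transform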
if (x.shape[1], x.shape[2]) != self.resize.output_size:
x = tf.image.resize(
x, self.resize.output_size, method=self.resize.method, antialias=self.resize.antialias
)
batches = [x]
elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
# Sample transform (to tensor, resize)
samples = list(multithread_exec(self.sample_transforms, x))
# Batching
batches = self.batch_inputs(samples)
else:
raise TypeError(f"invalid input type: {type(x)}")
# Batch transforms (normalize)
batches = list(multithread_exec(self.normalize, batches))
return batches
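

if __name__ == "__main__":
    # Illustrative smoke test with arbitrary shapes (not required for library use): exercise both input paths
    processor = PreProcessor(output_size=(1024, 1024), batch_size=2)
    # A list of raw pages of different sizes is cast, resized and grouped into batches
    pages = [np.zeros((512, 512, 3), dtype=np.uint8), np.zeros((800, 600, 3), dtype=np.uint8)]
    print([batch.shape for batch in processor(pages)])  # a single batch of shape (2, 1024, 1024, 3)
    # An already-batched 4D array is only resized (if needed) and normalized
    batched = np.zeros((2, 1024, 1024, 3), dtype=np.float32)
    print([batch.shape for batch in processor(batched)])  # a single batch of shape (2, 1024, 1024, 3)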