|
from typing import List |
|
import numpy as np |
|
import torch |
|
|
|
from transformers.utils import is_tf_available |
|
|
|
# TensorFlow is a hard requirement of this module: every helper below uses
# `tf`, so importing the module fails immediately when TensorFlow is missing.
if not is_tf_available():
    raise ValueError("Please run `pip install tensorflow` to use the processor.")

import tensorflow as tf  # noqa: E402  (deliberately placed after the availability check)
|
|
|
# Per-channel mean and standard deviation used by `whiten`, in RGB order.
# Each unit-scale value is multiplied by 255 so the statistics are expressed
# on the 0-255 pixel scale.
MEAN_RGB = [channel * 255 for channel in (0.48145466, 0.4578275, 0.40821073)]
STDDEV_RGB = [channel * 255 for channel in (0.26862954, 0.26130258, 0.27577711)]
|
|
|
|
|
def crop_image(image: tf.Tensor, center_crop_fraction: float = 0.875):
    """Cut a centered square window out of an image.

    The window's side length is `center_crop_fraction` times the smaller of
    the image's first two dimensions, converted to int32.

    Args:
        image: A rank-3 tensor; the crop is applied to its first two axes and
            the last axis is kept whole.
        center_crop_fraction: Fraction of the shorter side to keep.

    Returns:
        The cropped slice of `image`.
    """
    dims = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
    side_float = center_crop_fraction * tf.math.minimum(dims[0], dims[1])

    # Compute the offsets from the un-rounded side length before either value
    # is cast to an integer.
    offsets = tf.cast((dims - side_float) / 2.0, dtype=tf.int32)
    side = tf.cast(side_float, dtype=tf.int32)

    first, second = offsets[0], offsets[1]
    return image[first : first + side, second : second + side, :]
|
|
|
|
|
def whiten(
    image: tf.Tensor,
) -> tf.Tensor:
    """Standardize an image with the module-level channel statistics.

    The input is converted to a float32 tensor, then `MEAN_RGB` is subtracted
    and the result is divided by `STDDEV_RGB`. Both constants are shaped
    [1, 1, 3] so they broadcast against the input's last axis.

    Args:
        image: Image data whose last axis has 3 channels.

    Returns:
        A float32 tensor holding the standardized image.
    """
    pixels = tf.cast(tf.convert_to_tensor(image), tf.float32)
    mean = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=pixels.dtype)
    stddev = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=pixels.dtype)
    return (pixels - mean) / stddev
|
|
|
|
|
def tf_image_reshape_crop(image: tf.Tensor, crop_size: int) -> tf.Tensor:
    """Resize an image with its aspect ratio preserved, then center-crop a square.

    The resize target bounds one axis by `crop_size` and the other by a very
    large number, so with `preserve_aspect_ratio=True` only the `crop_size`
    bound is expected to constrain the result. Which axis gets `crop_size` is
    chosen by comparing the image's first two dimensions.

    Args:
        image: A rank-3 image tensor.
        crop_size: Size bound applied to the smaller of the first two axes.

    Returns:
        The resized image, center-cropped to a square via `crop_image`.
    """
    # Effectively "no limit" for the longer axis of the resize target.
    unbounded = 100000

    def _resize_within(bounds):
        return tf.image.resize(image, bounds, method="bilinear", preserve_aspect_ratio=True, antialias=False)

    dims = tf.shape(image)
    resized = tf.cond(
        dims[0] > dims[1],
        lambda: _resize_within((unbounded, crop_size)),
        lambda: _resize_within((crop_size, unbounded)),
    )

    # A fraction of 1 keeps the full shorter side, yielding a square crop.
    return crop_image(image=resized, center_crop_fraction=1)
|
|
|
|
|
def _single_image_preprocess(image: np.ndarray, crop_size: int = 224, resize_only: bool = False):
    """Preprocess one image: resize (and optionally crop), then whiten.

    Args:
        image: The image as a numpy array.
        crop_size: Side length of the output image.
        resize_only: When true, resize straight to `crop_size` x `crop_size`
            without preserving the aspect ratio. Otherwise resize with the
            aspect ratio preserved and then center-crop.

    Returns:
        A torch tensor holding the processed image.
    """
    tensor = tf.constant(image)

    if not resize_only:
        tensor = tf_image_reshape_crop(tensor, crop_size)
    else:
        target = (crop_size, crop_size)
        tensor = tf.image.resize(tensor, target, method="bilinear", preserve_aspect_ratio=False, antialias=False)

    normalized = whiten(tensor)

    # Hand the result over to torch through its numpy representation.
    return torch.asarray(normalized.numpy())
|
|
|
|
|
def image_preprocess(images: List[np.ndarray], crop_size: int = 224, resize_only: bool = False):
    """Preprocess a batch of images using TensorFlow's resizing functions.

    Each image is processed independently by `_single_image_preprocess`; the
    results are stacked into one batch and moved to channels-first layout.

    Args:
        images: A list of numpy arrays, one per image. Must be non-empty,
            since `torch.stack` rejects an empty sequence.
        crop_size: Side length of each processed image.
        resize_only: If true, only resize each image to the crop size;
            otherwise first resize, then center crop.

    Returns:
        A torch tensor with shape [size_of_images, 3, crop_size, crop_size].
    """
    # `_single_image_preprocess` converts its input with `tf.constant` itself,
    # so the arrays are passed through without a second conversion here.
    processed_images = [
        _single_image_preprocess(image, crop_size=crop_size, resize_only=resize_only) for image in images
    ]

    # Stacking yields [N, H, W, C]; permute to channels-first [N, C, H, W].
    return torch.permute(torch.stack(processed_images, 0), (0, 3, 1, 2))