File size: 3,163 Bytes
2eac4d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from typing import List
import numpy as np
import torch
from transformers.utils import is_tf_available
if is_tf_available():
import tensorflow as tf # type: ignore
else:
raise ValueError("Please run `pip install tensorflow` to use the processor.")
# Per-channel normalisation constants consumed by `whiten`: each entry is a
# unit-scale value multiplied by 255.
MEAN_RGB = [channel * 255 for channel in (0.48145466, 0.4578275, 0.40821073)]
STDDEV_RGB = [channel * 255 for channel in (0.26862954, 0.26130258, 0.27577711)]
def crop_image(image: tf.Tensor, center_crop_fraction: float = 0.875):
    """Slice a centred square window out of ``image``.

    The window's side length is ``center_crop_fraction`` times the smaller of
    the first two dimensions of ``image``, cast to ``int32``.

    Args:
        image: tensor indexed as ``[rows, cols, rest]``; the third axis is
            kept whole.
        center_crop_fraction: fraction of the shorter edge to keep.

    Returns:
        The square slice of ``image`` centred on both leading axes.
    """
    dims = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
    side_float = center_crop_fraction * tf.math.minimum(dims[0], dims[1])
    # Offsets are derived from the float side length before it is cast, so
    # the window stays centred on both axes.
    start = tf.cast((dims - side_float) / 2.0, dtype=tf.int32)
    side = tf.cast(side_float, dtype=tf.int32)
    top, left = start[0], start[1]
    return image[top : top + side, left : left + side, :]  # noqa: E203
def whiten(
    image: tf.Tensor,
) -> tf.Tensor:
    """Normalise an image with the module's per-channel mean and stddev.

    Args:
        image: value convertible to a tensor; it is cast to ``float32`` and
            must broadcast against a ``[1, 1, 3]`` constant.

    Returns:
        ``(image - MEAN_RGB) / STDDEV_RGB`` as a ``float32`` tensor.
    """
    pixels = tf.cast(tf.convert_to_tensor(image), tf.float32)
    mean = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=pixels.dtype)
    stddev = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=pixels.dtype)
    return (pixels - mean) / stddev
def tf_image_reshape_crop(image: tf.Tensor, crop_size: int) -> tf.Tensor:
    """Resize ``image`` so its shorter edge fits ``crop_size``, then centre-crop a square.

    The resize is bilinear, keeps the aspect ratio and applies no
    antialiasing; the crop keeps the full shorter edge
    (``center_crop_fraction=1``).

    Args:
        image: tensor whose first two axes are compared as height and width.
        crop_size: target length for the shorter edge.

    Returns:
        The resized and centre-cropped image tensor.
    """
    # 100000 is chosen as no image would have 100000 pixels along one edge,
    # so only the `crop_size` bound constrains the aspect-preserving resize.
    unreachable_edge = 100000

    def _bound_second_axis():
        return tf.image.resize(
            image, (unreachable_edge, crop_size), method="bilinear", preserve_aspect_ratio=True, antialias=False
        )

    def _bound_first_axis():
        return tf.image.resize(
            image, (crop_size, unreachable_edge), method="bilinear", preserve_aspect_ratio=True, antialias=False
        )

    dims = tf.shape(image)
    resized = tf.cond(dims[0] > dims[1], _bound_second_axis, _bound_first_axis)
    return crop_image(image=resized, center_crop_fraction=1)
def _single_image_preprocess(image: np.ndarray, crop_size: int = 224, resize_only: bool = False):
    """Resize (and optionally centre-crop) one image, then whiten it.

    Args:
        image: image in numpy array.
        crop_size: the size of the cropped image.
        resize_only: If true, only resize to ``(crop_size, crop_size)``
            without preserving the aspect ratio; otherwise resize with the
            aspect ratio preserved and then centre-crop.

    Returns:
        A torch tensor with the processed image, in the same axis order as
        the TensorFlow result.
    """
    tensor = tf.constant(image)
    if not resize_only:
        tensor = tf_image_reshape_crop(tensor, crop_size)
    else:
        target_shape = (crop_size, crop_size)
        tensor = tf.image.resize(
            tensor, target_shape, method="bilinear", preserve_aspect_ratio=False, antialias=False
        )
    whitened = whiten(tensor)
    return torch.asarray(whitened.numpy())
def image_preprocess(images: List[np.ndarray], crop_size: int = 224, resize_only: bool = False):
    """Image preprocess using tf resizing function.

    Each image is processed independently by `_single_image_preprocess`, the
    results are stacked into one batch, and the channel axis is moved in
    front of the spatial axes.

    Args:
        images: A list of numpy arrays, one per image.
        crop_size: the size of the cropped images.
        resize_only: If true, only resize each image to the crop size;
            otherwise, first resize then center crop.

    Returns:
        A torch tensor with shape [size_of_images, 3, crop_size, crop_size]
        (channels-first). An empty ``images`` yields an empty float32 batch
        of shape [0, 3, crop_size, crop_size].
    """
    # `_single_image_preprocess` converts its input to a TF tensor itself, so
    # no conversion is needed here.
    processed_images = [
        _single_image_preprocess(image, crop_size=crop_size, resize_only=resize_only) for image in images
    ]
    if not processed_images:
        # torch.stack rejects an empty list; return an empty batch that keeps
        # the documented channels-first layout instead.
        return torch.empty((0, 3, crop_size, crop_size), dtype=torch.float32)
    # Stack to [N, H, W, C], then reorder to [N, C, H, W].
    return torch.permute(torch.stack(processed_images, 0), (0, 3, 1, 2))