# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Preprocessing functions for images.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from typing import List, Optional, Text, Tuple import tensorflow as tf, tf_keras from official.legacy.image_classification import augment # Calculated from the ImageNet training set MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255) STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) IMAGE_SIZE = 224 CROP_PADDING = 32 def mean_image_subtraction( image_bytes: tf.Tensor, means: Tuple[float, ...], num_channels: int = 3, dtype: tf.dtypes.DType = tf.float32, ) -> tf.Tensor: """Subtracts the given means from each image channel. For example: means = [123.68, 116.779, 103.939] image_bytes = mean_image_subtraction(image_bytes, means) Note that the rank of `image` must be known. Args: image_bytes: a tensor of size [height, width, C]. means: a C-vector of values to subtract from each channel. num_channels: number of color channels in the image that will be distorted. dtype: the dtype to convert the images to. Set to `None` to skip conversion. Returns: the centered image. Raises: ValueError: If the rank of `image` is unknown, if `image` has a rank other than three or if the number of channels in `image` doesn't match the number of values in `means`. """ if image_bytes.get_shape().ndims != 3: raise ValueError('Input must be of size [height, width, C>0]') if len(means) != num_channels: raise ValueError('len(means) must match the number of channels') # We have a 1-D tensor of means; convert to 3-D. # Note(b/130245863): we explicitly call `broadcast` instead of simply # expanding dimensions for better performance. means = tf.broadcast_to(means, tf.shape(image_bytes)) if dtype is not None: means = tf.cast(means, dtype=dtype) return image_bytes - means def standardize_image( image_bytes: tf.Tensor, stddev: Tuple[float, ...], num_channels: int = 3, dtype: tf.dtypes.DType = tf.float32, ) -> tf.Tensor: """Divides the given stddev from each image channel. For example: stddev = [123.68, 116.779, 103.939] image_bytes = standardize_image(image_bytes, stddev) Note that the rank of `image` must be known. Args: image_bytes: a tensor of size [height, width, C]. stddev: a C-vector of values to divide from each channel. num_channels: number of color channels in the image that will be distorted. dtype: the dtype to convert the images to. Set to `None` to skip conversion. Returns: the centered image. Raises: ValueError: If the rank of `image` is unknown, if `image` has a rank other than three or if the number of channels in `image` doesn't match the number of values in `stddev`. """ if image_bytes.get_shape().ndims != 3: raise ValueError('Input must be of size [height, width, C>0]') if len(stddev) != num_channels: raise ValueError('len(stddev) must match the number of channels') # We have a 1-D tensor of stddev; convert to 3-D. # Note(b/130245863): we explicitly call `broadcast` instead of simply # expanding dimensions for better performance. stddev = tf.broadcast_to(stddev, tf.shape(image_bytes)) if dtype is not None: stddev = tf.cast(stddev, dtype=dtype) return image_bytes / stddev def normalize_images(features: tf.Tensor, mean_rgb: Tuple[float, ...] = MEAN_RGB, stddev_rgb: Tuple[float, ...] = STDDEV_RGB, num_channels: int = 3, dtype: tf.dtypes.DType = tf.float32, data_format: Text = 'channels_last') -> tf.Tensor: """Normalizes the input image channels with the given mean and stddev. Args: features: `Tensor` representing decoded images in float format. mean_rgb: the mean of the channels to subtract. stddev_rgb: the stddev of the channels to divide. num_channels: the number of channels in the input image tensor. dtype: the dtype to convert the images to. Set to `None` to skip conversion. data_format: the format of the input image tensor ['channels_first', 'channels_last']. Returns: A normalized image `Tensor`. """ # TODO(allencwang) - figure out how to use mean_image_subtraction and # standardize_image on batches of images and replace the following. if data_format == 'channels_first': stats_shape = [num_channels, 1, 1] else: stats_shape = [1, 1, num_channels] if dtype is not None: features = tf.image.convert_image_dtype(features, dtype=dtype) if mean_rgb is not None: mean_rgb = tf.constant(mean_rgb, shape=stats_shape, dtype=features.dtype) mean_rgb = tf.broadcast_to(mean_rgb, tf.shape(features)) features = features - mean_rgb if stddev_rgb is not None: stddev_rgb = tf.constant(stddev_rgb, shape=stats_shape, dtype=features.dtype) stddev_rgb = tf.broadcast_to(stddev_rgb, tf.shape(features)) features = features / stddev_rgb return features def decode_and_center_crop(image_bytes: tf.Tensor, image_size: int = IMAGE_SIZE, crop_padding: int = CROP_PADDING) -> tf.Tensor: """Crops to center of image with padding then scales image_size. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. image_size: image height/width dimension. crop_padding: the padding size to use when centering the crop. Returns: A decoded and cropped image `Tensor`. """ decoded = image_bytes.dtype != tf.string shape = (tf.shape(image_bytes) if decoded else tf.image.extract_jpeg_shape(image_bytes)) image_height = shape[0] image_width = shape[1] padded_center_crop_size = tf.cast( ((image_size / (image_size + crop_padding)) * tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 crop_window = tf.stack([offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]) if decoded: image = tf.image.crop_to_bounding_box( image_bytes, offset_height=offset_height, offset_width=offset_width, target_height=padded_center_crop_size, target_width=padded_center_crop_size) else: image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = resize_image(image_bytes=image, height=image_size, width=image_size) return image def decode_crop_and_flip(image_bytes: tf.Tensor) -> tf.Tensor: """Crops an image to a random part of the image, then randomly flips. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. Returns: A decoded and cropped image `Tensor`. """ decoded = image_bytes.dtype != tf.string bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) shape = (tf.shape(image_bytes) if decoded else tf.image.extract_jpeg_shape(image_bytes)) sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( shape, bounding_boxes=bbox, min_object_covered=0.1, aspect_ratio_range=[0.75, 1.33], area_range=[0.05, 1.0], max_attempts=100, use_image_if_no_bounding_boxes=True) bbox_begin, bbox_size, _ = sample_distorted_bounding_box # Reassemble the bounding box in the format the crop op requires. offset_height, offset_width, _ = tf.unstack(bbox_begin) target_height, target_width, _ = tf.unstack(bbox_size) crop_window = tf.stack([offset_height, offset_width, target_height, target_width]) if decoded: cropped = tf.image.crop_to_bounding_box( image_bytes, offset_height=offset_height, offset_width=offset_width, target_height=target_height, target_width=target_width) else: cropped = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) # Flip to add a little more random distortion in. cropped = tf.image.random_flip_left_right(cropped) return cropped def resize_image(image_bytes: tf.Tensor, height: int = IMAGE_SIZE, width: int = IMAGE_SIZE) -> tf.Tensor: """Resizes an image to a given height and width. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. height: image height dimension. width: image width dimension. Returns: A tensor containing the resized image. """ print(height, width) return tf.compat.v1.image.resize( image_bytes, tf.convert_to_tensor([height, width]), method=tf.image.ResizeMethod.BILINEAR, align_corners=False) def preprocess_for_eval( image_bytes: tf.Tensor, image_size: int = IMAGE_SIZE, num_channels: int = 3, mean_subtract: bool = False, standardize: bool = False, dtype: tf.dtypes.DType = tf.float32 ) -> tf.Tensor: """Preprocesses the given image for evaluation. Args: image_bytes: `Tensor` representing an image binary of arbitrary size. image_size: image height/width dimension. num_channels: number of image input channels. mean_subtract: whether or not to apply mean subtraction. standardize: whether or not to apply standardization. dtype: the dtype to convert the images to. Set to `None` to skip conversion. Returns: A preprocessed and normalized image `Tensor`. """ images = decode_and_center_crop(image_bytes, image_size) images = tf.reshape(images, [image_size, image_size, num_channels]) if mean_subtract: images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB) if standardize: images = standardize_image(image_bytes=images, stddev=STDDEV_RGB) if dtype is not None: images = tf.image.convert_image_dtype(images, dtype=dtype) return images def load_eval_image(filename: Text, image_size: int = IMAGE_SIZE) -> tf.Tensor: """Reads an image from the filesystem and applies image preprocessing. Args: filename: a filename path of an image. image_size: image height/width dimension. Returns: A preprocessed and normalized image `Tensor`. """ image_bytes = tf.io.read_file(filename) image = preprocess_for_eval(image_bytes, image_size) return image def build_eval_dataset(filenames: List[Text], labels: Optional[List[int]] = None, image_size: int = IMAGE_SIZE, batch_size: int = 1) -> tf.Tensor: """Builds a tf.data.Dataset from a list of filenames and labels. Args: filenames: a list of filename paths of images. labels: a list of labels corresponding to each image. image_size: image height/width dimension. batch_size: the batch size used by the dataset Returns: A preprocessed and normalized image `Tensor`. """ if labels is None: labels = [0] * len(filenames) filenames = tf.constant(filenames) labels = tf.constant(labels) dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) dataset = dataset.map( lambda filename, label: (load_eval_image(filename, image_size), label)) dataset = dataset.batch(batch_size) return dataset def preprocess_for_train(image_bytes: tf.Tensor, image_size: int = IMAGE_SIZE, augmenter: Optional[augment.ImageAugment] = None, mean_subtract: bool = False, standardize: bool = False, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: """Preprocesses the given image for training. Args: image_bytes: `Tensor` representing an image binary of arbitrary size of dtype tf.uint8. image_size: image height/width dimension. augmenter: the image augmenter to apply. mean_subtract: whether or not to apply mean subtraction. standardize: whether or not to apply standardization. dtype: the dtype to convert the images to. Set to `None` to skip conversion. Returns: A preprocessed and normalized image `Tensor`. """ images = decode_crop_and_flip(image_bytes=image_bytes) images = resize_image(images, height=image_size, width=image_size) if augmenter is not None: images = augmenter.distort(images) if mean_subtract: images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB) if standardize: images = standardize_image(image_bytes=images, stddev=STDDEV_RGB) if dtype is not None: images = tf.image.convert_image_dtype(images, dtype) return images