deanna-emery's picture
updates
93528c6
raw
history blame
4.02 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semantic segmentation input and model functions for serving/inference."""
import tensorflow as tf, tf_keras
from official.vision.modeling import factory
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base
class SegmentationModule(export_base.ExportModule):
"""Segmentation Module."""
def _build_model(self):
input_specs = tf_keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3])
return factory.build_segmentation_model(
input_specs=input_specs,
model_config=self.params.task.model,
l2_regularizer=None)
def _build_inputs(self, image):
"""Builds classification model inputs for serving."""
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(
image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
if self.params.task.train_data.preserve_aspect_ratio:
image, image_info = preprocess_ops.resize_and_crop_image(
image,
self._input_image_size,
padded_size=self._input_image_size,
aug_scale_min=1.0,
aug_scale_max=1.0)
else:
image, image_info = preprocess_ops.resize_image(image,
self._input_image_size)
return image, image_info
def serve(self, images):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding classification output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
image_info = None
if self._input_type != 'tflite':
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images_spec = tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32)
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
images, image_info = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs,
elems=images,
fn_output_signature=(images_spec, image_info_spec),
parallel_iterations=32))
outputs = self.inference_step(images)
# Optionally resize prediction to the input image size.
if self.params.task.export_config.rescale_output:
logits = outputs['logits']
if logits.shape[0] != 1:
raise ValueError('Batch size cannot be more than 1.')
image_shape = tf.cast(image_info[0, 0, :], tf.int32)
if self.params.task.train_data.preserve_aspect_ratio:
rescale_size = tf.cast(
tf.math.ceil(image_info[0, 1, :] / image_info[0, 2, :]), tf.int32)
offsets = tf.cast(image_info[0, 3, :], tf.int32)
logits = tf.image.resize(logits, rescale_size, method='bilinear')
outputs['logits'] = tf.image.crop_to_bounding_box(
logits, offsets[0], offsets[1], image_shape[0], image_shape[1])
else:
outputs['logits'] = tf.image.resize(
logits, [image_shape[0], image_shape[1]], method='bilinear')
else:
outputs['logits'] = tf.image.resize(
outputs['logits'], self._input_image_size, method='bilinear')
if image_info is not None:
outputs.update({'image_info': image_info})
return outputs