Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Semantic segmentation input and model functions for serving/inference.""" | |
import tensorflow as tf, tf_keras | |
from official.vision.modeling import factory | |
from official.vision.ops import preprocess_ops | |
from official.vision.serving import export_base | |
class SegmentationModule(export_base.ExportModule): | |
"""Segmentation Module.""" | |
def _build_model(self): | |
input_specs = tf_keras.layers.InputSpec( | |
shape=[self._batch_size] + self._input_image_size + [3]) | |
return factory.build_segmentation_model( | |
input_specs=input_specs, | |
model_config=self.params.task.model, | |
l2_regularizer=None) | |
def _build_inputs(self, image): | |
"""Builds classification model inputs for serving.""" | |
# Normalizes image with mean and std pixel values. | |
image = preprocess_ops.normalize_image( | |
image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB) | |
if self.params.task.train_data.preserve_aspect_ratio: | |
image, image_info = preprocess_ops.resize_and_crop_image( | |
image, | |
self._input_image_size, | |
padded_size=self._input_image_size, | |
aug_scale_min=1.0, | |
aug_scale_max=1.0) | |
else: | |
image, image_info = preprocess_ops.resize_image(image, | |
self._input_image_size) | |
return image, image_info | |
def serve(self, images): | |
"""Cast image to float and run inference. | |
Args: | |
images: uint8 Tensor of shape [batch_size, None, None, 3] | |
Returns: | |
Tensor holding classification output logits. | |
""" | |
# Skip image preprocessing when input_type is tflite so it is compatible | |
# with TFLite quantization. | |
image_info = None | |
if self._input_type != 'tflite': | |
with tf.device('cpu:0'): | |
images = tf.cast(images, dtype=tf.float32) | |
images_spec = tf.TensorSpec( | |
shape=self._input_image_size + [3], dtype=tf.float32) | |
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32) | |
images, image_info = tf.nest.map_structure( | |
tf.identity, | |
tf.map_fn( | |
self._build_inputs, | |
elems=images, | |
fn_output_signature=(images_spec, image_info_spec), | |
parallel_iterations=32)) | |
outputs = self.inference_step(images) | |
# Optionally resize prediction to the input image size. | |
if self.params.task.export_config.rescale_output: | |
logits = outputs['logits'] | |
if logits.shape[0] != 1: | |
raise ValueError('Batch size cannot be more than 1.') | |
image_shape = tf.cast(image_info[0, 0, :], tf.int32) | |
if self.params.task.train_data.preserve_aspect_ratio: | |
rescale_size = tf.cast( | |
tf.math.ceil(image_info[0, 1, :] / image_info[0, 2, :]), tf.int32) | |
offsets = tf.cast(image_info[0, 3, :], tf.int32) | |
logits = tf.image.resize(logits, rescale_size, method='bilinear') | |
outputs['logits'] = tf.image.crop_to_bounding_box( | |
logits, offsets[0], offsets[1], image_shape[0], image_shape[1]) | |
else: | |
outputs['logits'] = tf.image.resize( | |
logits, [image_shape[0], image_shape[1]], method='bilinear') | |
else: | |
outputs['logits'] = tf.image.resize( | |
outputs['logits'], self._input_image_size, method='bilinear') | |
if image_info is not None: | |
outputs.update({'image_info': image_info}) | |
return outputs | |