Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image classification input and model functions for serving/inference."""
import tensorflow as tf, tf_keras | |
from official.vision.modeling import factory | |
from official.vision.ops import preprocess_ops | |
from official.vision.serving import export_base | |
class ClassificationModule(export_base.ExportModule):
  """Export module that preprocesses images and serves a classification model."""

  def _build_model(self):
    """Creates the classification model described by the task's model config.

    Returns:
      The model returned by `factory.build_classification_model` for inputs of
      shape `[batch_size] + input_image_size + [3]`, built without an L2
      regularizer.
    """
    image_shape = [self._batch_size] + self._input_image_size + [3]
    return factory.build_classification_model(
        input_specs=tf_keras.layers.InputSpec(shape=image_shape),
        model_config=self.params.task.model,
        l2_regularizer=None)

  def _build_inputs(self, image):
    """Preprocesses one image into the form the model consumes.

    Args:
      image: Image tensor for a single example; `serve` passes float32
        elements of the incoming batch.

    Returns:
      The image, optionally center-cropped, bilinearly resized to the
      configured input size, reshaped to [height, width, 3] and normalized
      with `preprocess_ops.MEAN_RGB` / `preprocess_ops.STDDEV_RGB`.
    """
    # The center crop is applied only when the training data config enables
    # `aug_crop`.
    if self.params.task.train_data.aug_crop:
      image = preprocess_ops.center_crop_image(image)

    resized = tf.image.resize(
        image, self._input_image_size, method=tf.image.ResizeMethod.BILINEAR)

    # Reshape to the fully specified [height, width, 3] shape.
    target_shape = [self._input_image_size[0], self._input_image_size[1], 3]
    resized = tf.reshape(resized, target_shape)

    # Normalize with the mean and standard deviation pixel values.
    return preprocess_ops.normalize_image(
        resized,
        offset=preprocess_ops.MEAN_RGB,
        scale=preprocess_ops.STDDEV_RGB)

  def serve(self, images):
    """Casts images to float, preprocesses them and runs inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3].

    Returns:
      A dict with two entries: 'logits', the output of `inference_step`, and
      'probs', the sigmoid of the logits when the task is multilabel or the
      softmax of the logits otherwise.
    """
    # Preprocessing is skipped for the 'tflite' input type so the exported
    # function stays compatible with TFLite quantization.
    needs_preprocessing = self._input_type != 'tflite'
    if needs_preprocessing:
      with tf.device('cpu:0'):
        float_images = tf.cast(images, dtype=tf.float32)
        per_image_spec = tf.TensorSpec(
            shape=self._input_image_size + [3], dtype=tf.float32)
        preprocessed = tf.map_fn(
            self._build_inputs,
            elems=float_images,
            fn_output_signature=per_image_spec,
            parallel_iterations=32)
        images = tf.nest.map_structure(tf.identity, preprocessed)

    logits = self.inference_step(images)

    # Multilabel tasks score each class independently; otherwise the classes
    # are normalized jointly with a softmax.
    if self.params.task.train_data.is_multilabel:
      to_probs = tf.math.sigmoid
    else:
      to_probs = tf.nn.softmax
    return {'logits': logits, 'probs': to_probs(logits)}